feat: PocketVeto v1.0.0 — initial public release
Self-hosted US Congress monitoring platform with AI policy briefs, bill/member/topic follows, ntfy + RSS + email notifications, alignment scoring, collections, and draft-letter generator. Authored by: Jack Levy
This commit is contained in:
0
backend/app/api/__init__.py
Normal file
0
backend/app/api/__init__.py
Normal file
497
backend/app/api/admin.py
Normal file
497
backend/app/api/admin.py
Normal file
@@ -0,0 +1,497 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy import func, select, text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.dependencies import get_current_admin
|
||||
from app.database import get_db
|
||||
from app.models import Bill, BillBrief, BillDocument, Follow
|
||||
from app.models.user import User
|
||||
from app.schemas.schemas import UserResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ── User Management ───────────────────────────────────────────────────────────
|
||||
|
||||
class UserWithStats(UserResponse):
|
||||
follow_count: int
|
||||
|
||||
|
||||
@router.get("/users", response_model=list[UserWithStats])
|
||||
async def list_users(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_admin),
|
||||
):
|
||||
"""List all users with their follow counts."""
|
||||
users_result = await db.execute(select(User).order_by(User.created_at))
|
||||
users = users_result.scalars().all()
|
||||
|
||||
counts_result = await db.execute(
|
||||
select(Follow.user_id, func.count(Follow.id).label("cnt"))
|
||||
.group_by(Follow.user_id)
|
||||
)
|
||||
counts = {row.user_id: row.cnt for row in counts_result}
|
||||
|
||||
return [
|
||||
UserWithStats(
|
||||
id=u.id,
|
||||
email=u.email,
|
||||
is_admin=u.is_admin,
|
||||
notification_prefs=u.notification_prefs or {},
|
||||
created_at=u.created_at,
|
||||
follow_count=counts.get(u.id, 0),
|
||||
)
|
||||
for u in users
|
||||
]
|
||||
|
||||
|
||||
@router.delete("/users/{user_id}", status_code=204)
|
||||
async def delete_user(
|
||||
user_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_admin),
|
||||
):
|
||||
"""Delete a user account (cascades to their follows). Cannot delete yourself."""
|
||||
if user_id == current_user.id:
|
||||
raise HTTPException(status_code=400, detail="Cannot delete your own account")
|
||||
user = await db.get(User, user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
await db.delete(user)
|
||||
await db.commit()
|
||||
|
||||
|
||||
@router.patch("/users/{user_id}/toggle-admin", response_model=UserResponse)
|
||||
async def toggle_admin(
|
||||
user_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_admin),
|
||||
):
|
||||
"""Promote or demote a user's admin status."""
|
||||
if user_id == current_user.id:
|
||||
raise HTTPException(status_code=400, detail="Cannot change your own admin status")
|
||||
user = await db.get(User, user_id)
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
user.is_admin = not user.is_admin
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
# ── Analysis Stats ────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/stats")
|
||||
async def get_stats(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_admin),
|
||||
):
|
||||
"""Return analysis pipeline progress counters."""
|
||||
total_bills = (await db.execute(select(func.count()).select_from(Bill))).scalar()
|
||||
docs_fetched = (await db.execute(
|
||||
select(func.count()).select_from(BillDocument).where(BillDocument.raw_text.isnot(None))
|
||||
)).scalar()
|
||||
total_briefs = (await db.execute(select(func.count()).select_from(BillBrief))).scalar()
|
||||
full_briefs = (await db.execute(
|
||||
select(func.count()).select_from(BillBrief).where(BillBrief.brief_type == "full")
|
||||
)).scalar()
|
||||
amendment_briefs = (await db.execute(
|
||||
select(func.count()).select_from(BillBrief).where(BillBrief.brief_type == "amendment")
|
||||
)).scalar()
|
||||
uncited_briefs = (await db.execute(
|
||||
text("""
|
||||
SELECT COUNT(*) FROM bill_briefs
|
||||
WHERE key_points IS NOT NULL
|
||||
AND jsonb_array_length(key_points) > 0
|
||||
AND jsonb_typeof(key_points->0) = 'string'
|
||||
""")
|
||||
)).scalar()
|
||||
# Bills with null sponsor
|
||||
bills_missing_sponsor = (await db.execute(
|
||||
text("SELECT COUNT(*) FROM bills WHERE sponsor_id IS NULL")
|
||||
)).scalar()
|
||||
# Bills with null metadata (introduced_date / chamber / congress_url)
|
||||
bills_missing_metadata = (await db.execute(
|
||||
text("SELECT COUNT(*) FROM bills WHERE introduced_date IS NULL OR chamber IS NULL OR congress_url IS NULL")
|
||||
)).scalar()
|
||||
# Bills with no document record at all (text not yet published on GovInfo)
|
||||
no_text_bills = (await db.execute(
|
||||
text("""
|
||||
SELECT COUNT(*) FROM bills b
|
||||
LEFT JOIN bill_documents bd ON bd.bill_id = b.bill_id
|
||||
WHERE bd.id IS NULL
|
||||
""")
|
||||
)).scalar()
|
||||
# Documents that have text but no brief (LLM not yet run / failed)
|
||||
pending_llm = (await db.execute(
|
||||
text("""
|
||||
SELECT COUNT(*) FROM bill_documents bd
|
||||
LEFT JOIN bill_briefs bb ON bb.document_id = bd.id
|
||||
WHERE bb.id IS NULL AND bd.raw_text IS NOT NULL
|
||||
""")
|
||||
)).scalar()
|
||||
# Bills that have never had their action history fetched
|
||||
bills_missing_actions = (await db.execute(
|
||||
text("SELECT COUNT(*) FROM bills WHERE actions_fetched_at IS NULL")
|
||||
)).scalar()
|
||||
# Cited brief points (objects) that have no label yet
|
||||
unlabeled_briefs = (await db.execute(
|
||||
text("""
|
||||
SELECT COUNT(*) FROM bill_briefs
|
||||
WHERE (
|
||||
key_points IS NOT NULL AND EXISTS (
|
||||
SELECT 1 FROM jsonb_array_elements(key_points) AS p
|
||||
WHERE jsonb_typeof(p) = 'object' AND (p->>'label') IS NULL
|
||||
)
|
||||
) OR (
|
||||
risks IS NOT NULL AND EXISTS (
|
||||
SELECT 1 FROM jsonb_array_elements(risks) AS r
|
||||
WHERE jsonb_typeof(r) = 'object' AND (r->>'label') IS NULL
|
||||
)
|
||||
)
|
||||
""")
|
||||
)).scalar()
|
||||
return {
|
||||
"total_bills": total_bills,
|
||||
"docs_fetched": docs_fetched,
|
||||
"briefs_generated": total_briefs,
|
||||
"full_briefs": full_briefs,
|
||||
"amendment_briefs": amendment_briefs,
|
||||
"uncited_briefs": uncited_briefs,
|
||||
"no_text_bills": no_text_bills,
|
||||
"pending_llm": pending_llm,
|
||||
"bills_missing_sponsor": bills_missing_sponsor,
|
||||
"bills_missing_metadata": bills_missing_metadata,
|
||||
"bills_missing_actions": bills_missing_actions,
|
||||
"unlabeled_briefs": unlabeled_briefs,
|
||||
"remaining": total_bills - total_briefs,
|
||||
}
|
||||
|
||||
|
||||
# ── Celery Tasks ──────────────────────────────────────────────────────────────
|
||||
|
||||
@router.post("/backfill-citations")
|
||||
async def backfill_citations(current_user: User = Depends(get_current_admin)):
|
||||
"""Delete pre-citation briefs and re-queue LLM processing using stored document text."""
|
||||
from app.workers.llm_processor import backfill_brief_citations
|
||||
task = backfill_brief_citations.delay()
|
||||
return {"task_id": task.id, "status": "queued"}
|
||||
|
||||
|
||||
@router.post("/backfill-sponsors")
|
||||
async def backfill_sponsors(current_user: User = Depends(get_current_admin)):
|
||||
from app.workers.congress_poller import backfill_sponsor_ids
|
||||
task = backfill_sponsor_ids.delay()
|
||||
return {"task_id": task.id, "status": "queued"}
|
||||
|
||||
|
||||
@router.post("/trigger-poll")
|
||||
async def trigger_poll(current_user: User = Depends(get_current_admin)):
|
||||
from app.workers.congress_poller import poll_congress_bills
|
||||
task = poll_congress_bills.delay()
|
||||
return {"task_id": task.id, "status": "queued"}
|
||||
|
||||
|
||||
@router.post("/trigger-member-sync")
|
||||
async def trigger_member_sync(current_user: User = Depends(get_current_admin)):
|
||||
from app.workers.congress_poller import sync_members
|
||||
task = sync_members.delay()
|
||||
return {"task_id": task.id, "status": "queued"}
|
||||
|
||||
|
||||
@router.post("/trigger-fetch-actions")
|
||||
async def trigger_fetch_actions(current_user: User = Depends(get_current_admin)):
|
||||
from app.workers.congress_poller import fetch_actions_for_active_bills
|
||||
task = fetch_actions_for_active_bills.delay()
|
||||
return {"task_id": task.id, "status": "queued"}
|
||||
|
||||
|
||||
@router.post("/backfill-all-actions")
|
||||
async def backfill_all_actions(current_user: User = Depends(get_current_admin)):
|
||||
"""Queue action fetches for every bill that has never had actions fetched."""
|
||||
from app.workers.congress_poller import backfill_all_bill_actions
|
||||
task = backfill_all_bill_actions.delay()
|
||||
return {"task_id": task.id, "status": "queued"}
|
||||
|
||||
|
||||
@router.post("/backfill-metadata")
|
||||
async def backfill_metadata(current_user: User = Depends(get_current_admin)):
|
||||
"""Fill in null introduced_date, congress_url, chamber for existing bills."""
|
||||
from app.workers.congress_poller import backfill_bill_metadata
|
||||
task = backfill_bill_metadata.delay()
|
||||
return {"task_id": task.id, "status": "queued"}
|
||||
|
||||
|
||||
@router.post("/backfill-labels")
|
||||
async def backfill_labels(current_user: User = Depends(get_current_admin)):
|
||||
"""Classify existing cited brief points as fact or inference without re-generating briefs."""
|
||||
from app.workers.llm_processor import backfill_brief_labels
|
||||
task = backfill_brief_labels.delay()
|
||||
return {"task_id": task.id, "status": "queued"}
|
||||
|
||||
|
||||
@router.post("/backfill-cosponsors")
|
||||
async def backfill_cosponsors(current_user: User = Depends(get_current_admin)):
|
||||
"""Fetch co-sponsor data from Congress.gov for all bills that haven't been fetched yet."""
|
||||
from app.workers.bill_classifier import backfill_all_bill_cosponsors
|
||||
task = backfill_all_bill_cosponsors.delay()
|
||||
return {"task_id": task.id, "status": "queued"}
|
||||
|
||||
|
||||
@router.post("/backfill-categories")
|
||||
async def backfill_categories(current_user: User = Depends(get_current_admin)):
|
||||
"""Classify all bills with text but no category as substantive/commemorative/administrative."""
|
||||
from app.workers.bill_classifier import backfill_bill_categories
|
||||
task = backfill_bill_categories.delay()
|
||||
return {"task_id": task.id, "status": "queued"}
|
||||
|
||||
|
||||
@router.post("/calculate-effectiveness")
|
||||
async def calculate_effectiveness(current_user: User = Depends(get_current_admin)):
|
||||
"""Recalculate member effectiveness scores and percentiles now."""
|
||||
from app.workers.bill_classifier import calculate_effectiveness_scores
|
||||
task = calculate_effectiveness_scores.delay()
|
||||
return {"task_id": task.id, "status": "queued"}
|
||||
|
||||
|
||||
@router.post("/resume-analysis")
|
||||
async def resume_analysis(current_user: User = Depends(get_current_admin)):
|
||||
"""Re-queue LLM processing for docs with no brief, and document fetching for bills with no doc."""
|
||||
from app.workers.llm_processor import resume_pending_analysis
|
||||
task = resume_pending_analysis.delay()
|
||||
return {"task_id": task.id, "status": "queued"}
|
||||
|
||||
|
||||
@router.post("/trigger-weekly-digest")
|
||||
async def trigger_weekly_digest(current_user: User = Depends(get_current_admin)):
|
||||
"""Send the weekly bill activity summary to all eligible users now."""
|
||||
from app.workers.notification_dispatcher import send_weekly_digest
|
||||
task = send_weekly_digest.delay()
|
||||
return {"task_id": task.id, "status": "queued"}
|
||||
|
||||
|
||||
@router.post("/trigger-trend-scores")
|
||||
async def trigger_trend_scores(current_user: User = Depends(get_current_admin)):
|
||||
from app.workers.trend_scorer import calculate_all_trend_scores
|
||||
task = calculate_all_trend_scores.delay()
|
||||
return {"task_id": task.id, "status": "queued"}
|
||||
|
||||
|
||||
@router.post("/bills/{bill_id}/reprocess")
|
||||
async def reprocess_bill(bill_id: str, current_user: User = Depends(get_current_admin)):
|
||||
"""Queue document and action fetches for a specific bill. Useful for debugging."""
|
||||
from app.workers.document_fetcher import fetch_bill_documents
|
||||
from app.workers.congress_poller import fetch_bill_actions
|
||||
doc_task = fetch_bill_documents.delay(bill_id)
|
||||
actions_task = fetch_bill_actions.delay(bill_id)
|
||||
return {"task_ids": {"documents": doc_task.id, "actions": actions_task.id}}
|
||||
|
||||
|
||||
@router.get("/newsapi-quota")
|
||||
async def get_newsapi_quota(current_user: User = Depends(get_current_admin)):
|
||||
"""Return today's remaining NewsAPI daily quota (calls used vs. 100/day limit)."""
|
||||
from app.services.news_service import get_newsapi_quota_remaining
|
||||
import asyncio
|
||||
remaining = await asyncio.to_thread(get_newsapi_quota_remaining)
|
||||
return {"remaining": remaining, "limit": 100}
|
||||
|
||||
|
||||
@router.post("/clear-gnews-cache")
|
||||
async def clear_gnews_cache_endpoint(current_user: User = Depends(get_current_admin)):
|
||||
"""Flush the Google News RSS Redis cache so fresh data is fetched on next run."""
|
||||
from app.services.news_service import clear_gnews_cache
|
||||
import asyncio
|
||||
cleared = await asyncio.to_thread(clear_gnews_cache)
|
||||
return {"cleared": cleared}
|
||||
|
||||
|
||||
@router.post("/submit-llm-batch")
|
||||
async def submit_llm_batch_endpoint(current_user: User = Depends(get_current_admin)):
|
||||
"""Submit all unbriefed documents to the Batch API (OpenAI/Anthropic only)."""
|
||||
from app.workers.llm_batch_processor import submit_llm_batch
|
||||
task = submit_llm_batch.delay()
|
||||
return {"task_id": task.id, "status": "queued"}
|
||||
|
||||
|
||||
@router.get("/llm-batch-status")
|
||||
async def get_llm_batch_status(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_admin),
|
||||
):
|
||||
"""Return the current batch job state, or no_active_batch if none."""
|
||||
import json
|
||||
from app.models.setting import AppSetting
|
||||
row = await db.get(AppSetting, "llm_active_batch")
|
||||
if not row:
|
||||
return {"status": "no_active_batch"}
|
||||
try:
|
||||
return json.loads(row.value)
|
||||
except Exception:
|
||||
return {"status": "unknown"}
|
||||
|
||||
|
||||
@router.get("/api-health")
|
||||
async def api_health(current_user: User = Depends(get_current_admin)):
|
||||
"""Test each external API and return status + latency for each."""
|
||||
import asyncio
|
||||
results = await asyncio.gather(
|
||||
asyncio.to_thread(_test_congress),
|
||||
asyncio.to_thread(_test_govinfo),
|
||||
asyncio.to_thread(_test_newsapi),
|
||||
asyncio.to_thread(_test_gnews),
|
||||
asyncio.to_thread(_test_rep_lookup),
|
||||
return_exceptions=True,
|
||||
)
|
||||
keys = ["congress_gov", "govinfo", "newsapi", "google_news", "rep_lookup"]
|
||||
return {
|
||||
k: r if isinstance(r, dict) else {"status": "error", "detail": str(r)}
|
||||
for k, r in zip(keys, results)
|
||||
}
|
||||
|
||||
|
||||
def _timed(fn):
|
||||
"""Run fn(), return its dict merged with latency_ms."""
|
||||
import time as _time
|
||||
t0 = _time.perf_counter()
|
||||
result = fn()
|
||||
result["latency_ms"] = round((_time.perf_counter() - t0) * 1000)
|
||||
return result
|
||||
|
||||
|
||||
def _test_congress() -> dict:
|
||||
from app.config import settings
|
||||
from app.services import congress_api
|
||||
if not settings.DATA_GOV_API_KEY:
|
||||
return {"status": "error", "detail": "DATA_GOV_API_KEY not configured"}
|
||||
def _call():
|
||||
data = congress_api.get_bills(119, limit=1)
|
||||
count = data.get("pagination", {}).get("count") or len(data.get("bills", []))
|
||||
return {"status": "ok", "detail": f"{count:,} bills available in 119th Congress"}
|
||||
try:
|
||||
return _timed(_call)
|
||||
except Exception as exc:
|
||||
return {"status": "error", "detail": str(exc)}
|
||||
|
||||
|
||||
def _test_govinfo() -> dict:
|
||||
from app.config import settings
|
||||
import requests as req
|
||||
if not settings.DATA_GOV_API_KEY:
|
||||
return {"status": "error", "detail": "DATA_GOV_API_KEY not configured"}
|
||||
def _call():
|
||||
# /collections lists all available collections — simple health check endpoint
|
||||
resp = req.get(
|
||||
"https://api.govinfo.gov/collections",
|
||||
params={"api_key": settings.DATA_GOV_API_KEY},
|
||||
timeout=15,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
collections = data.get("collections", [])
|
||||
bills_col = next((c for c in collections if c.get("collectionCode") == "BILLS"), None)
|
||||
if bills_col:
|
||||
count = bills_col.get("packageCount", "?")
|
||||
return {"status": "ok", "detail": f"BILLS collection: {count:,} packages" if isinstance(count, int) else "GovInfo reachable, BILLS collection found"}
|
||||
return {"status": "ok", "detail": f"GovInfo reachable — {len(collections)} collections available"}
|
||||
try:
|
||||
return _timed(_call)
|
||||
except Exception as exc:
|
||||
return {"status": "error", "detail": str(exc)}
|
||||
|
||||
|
||||
def _test_newsapi() -> dict:
|
||||
from app.config import settings
|
||||
import requests as req
|
||||
if not settings.NEWSAPI_KEY:
|
||||
return {"status": "skipped", "detail": "NEWSAPI_KEY not configured"}
|
||||
def _call():
|
||||
resp = req.get(
|
||||
"https://newsapi.org/v2/top-headlines",
|
||||
params={"country": "us", "pageSize": 1, "apiKey": settings.NEWSAPI_KEY},
|
||||
timeout=10,
|
||||
)
|
||||
data = resp.json()
|
||||
if data.get("status") != "ok":
|
||||
return {"status": "error", "detail": data.get("message", "Unknown error")}
|
||||
return {"status": "ok", "detail": f"{data.get('totalResults', 0):,} headlines available"}
|
||||
try:
|
||||
return _timed(_call)
|
||||
except Exception as exc:
|
||||
return {"status": "error", "detail": str(exc)}
|
||||
|
||||
|
||||
def _test_gnews() -> dict:
|
||||
import requests as req
|
||||
def _call():
|
||||
resp = req.get(
|
||||
"https://news.google.com/rss/search",
|
||||
params={"q": "congress", "hl": "en-US", "gl": "US", "ceid": "US:en"},
|
||||
timeout=10,
|
||||
headers={"User-Agent": "Mozilla/5.0"},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
item_count = resp.text.count("<item>")
|
||||
return {"status": "ok", "detail": f"{item_count} items in test RSS feed"}
|
||||
try:
|
||||
return _timed(_call)
|
||||
except Exception as exc:
|
||||
return {"status": "error", "detail": str(exc)}
|
||||
|
||||
|
||||
def _test_rep_lookup() -> dict:
|
||||
import re as _re
|
||||
import requests as req
|
||||
def _call():
|
||||
# Step 1: Nominatim ZIP → lat/lng
|
||||
r1 = req.get(
|
||||
"https://nominatim.openstreetmap.org/search",
|
||||
params={"postalcode": "20001", "country": "US", "format": "json", "limit": "1"},
|
||||
headers={"User-Agent": "PocketVeto/1.0"},
|
||||
timeout=10,
|
||||
)
|
||||
r1.raise_for_status()
|
||||
places = r1.json()
|
||||
if not places:
|
||||
return {"status": "error", "detail": "Nominatim: no result for test ZIP 20001"}
|
||||
lat, lng = places[0]["lat"], places[0]["lon"]
|
||||
half = 0.5
|
||||
# Step 2: TIGERweb identify → congressional district
|
||||
r2 = req.get(
|
||||
"https://tigerweb.geo.census.gov/arcgis/rest/services/TIGERweb/Legislative/MapServer/identify",
|
||||
params={
|
||||
"f": "json",
|
||||
"geometry": f"{lng},{lat}",
|
||||
"geometryType": "esriGeometryPoint",
|
||||
"sr": "4326",
|
||||
"layers": "all",
|
||||
"tolerance": "2",
|
||||
"mapExtent": f"{float(lng)-half},{float(lat)-half},{float(lng)+half},{float(lat)+half}",
|
||||
"imageDisplay": "100,100,96",
|
||||
},
|
||||
timeout=15,
|
||||
)
|
||||
r2.raise_for_status()
|
||||
results = r2.json().get("results", [])
|
||||
for item in results:
|
||||
attrs = item.get("attributes", {})
|
||||
cd_field = next((k for k in attrs if _re.match(r"CD\d+FP$", k)), None)
|
||||
if cd_field:
|
||||
district = str(int(str(attrs[cd_field]))) if str(attrs[cd_field]).strip("0") else "At-large"
|
||||
return {"status": "ok", "detail": f"Nominatim + TIGERweb reachable — district {district} found for ZIP 20001"}
|
||||
layers = [r.get("layerName") for r in results]
|
||||
return {"status": "error", "detail": f"Reachable but no CD field found. Layers: {layers}"}
|
||||
try:
|
||||
return _timed(_call)
|
||||
except Exception as exc:
|
||||
return {"status": "error", "detail": str(exc)}
|
||||
|
||||
|
||||
@router.get("/task-status/{task_id}")
|
||||
async def get_task_status(task_id: str, current_user: User = Depends(get_current_admin)):
|
||||
from app.workers.celery_app import celery_app
|
||||
result = celery_app.AsyncResult(task_id)
|
||||
return {
|
||||
"task_id": task_id,
|
||||
"status": result.status,
|
||||
"result": result.result if result.ready() else None,
|
||||
}
|
||||
161
backend/app/api/alignment.py
Normal file
161
backend/app/api/alignment.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""
|
||||
Representation Alignment API.
|
||||
|
||||
Returns how well each followed member's voting record aligns with the
|
||||
current user's bill stances (pocket_veto / pocket_boost).
|
||||
"""
|
||||
from collections import defaultdict
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.dependencies import get_current_user
|
||||
from app.database import get_db
|
||||
from app.models import Follow, Member
|
||||
from app.models.user import User
|
||||
from app.models.vote import BillVote, MemberVotePosition
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def get_alignment(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Cross-reference the user's stanced bill follows with how their
|
||||
followed members voted on those same bills.
|
||||
|
||||
pocket_boost + Yea → aligned
|
||||
pocket_veto + Nay → aligned
|
||||
All other combinations with an actual Yea/Nay vote → opposed
|
||||
Not Voting / Present → excluded from tally
|
||||
"""
|
||||
# 1. Bill follows with a stance
|
||||
bill_follows_result = await db.execute(
|
||||
select(Follow).where(
|
||||
Follow.user_id == current_user.id,
|
||||
Follow.follow_type == "bill",
|
||||
Follow.follow_mode.in_(["pocket_veto", "pocket_boost"]),
|
||||
)
|
||||
)
|
||||
bill_follows = bill_follows_result.scalars().all()
|
||||
|
||||
if not bill_follows:
|
||||
return {
|
||||
"members": [],
|
||||
"total_bills_with_stance": 0,
|
||||
"total_bills_with_votes": 0,
|
||||
}
|
||||
|
||||
stance_map = {f.follow_value: f.follow_mode for f in bill_follows}
|
||||
|
||||
# 2. Followed members
|
||||
member_follows_result = await db.execute(
|
||||
select(Follow).where(
|
||||
Follow.user_id == current_user.id,
|
||||
Follow.follow_type == "member",
|
||||
)
|
||||
)
|
||||
member_follows = member_follows_result.scalars().all()
|
||||
followed_member_ids = {f.follow_value for f in member_follows}
|
||||
|
||||
if not followed_member_ids:
|
||||
return {
|
||||
"members": [],
|
||||
"total_bills_with_stance": len(stance_map),
|
||||
"total_bills_with_votes": 0,
|
||||
}
|
||||
|
||||
# 3. Bulk fetch votes for all stanced bills
|
||||
bill_ids = list(stance_map.keys())
|
||||
votes_result = await db.execute(
|
||||
select(BillVote).where(BillVote.bill_id.in_(bill_ids))
|
||||
)
|
||||
votes = votes_result.scalars().all()
|
||||
|
||||
if not votes:
|
||||
return {
|
||||
"members": [],
|
||||
"total_bills_with_stance": len(stance_map),
|
||||
"total_bills_with_votes": 0,
|
||||
}
|
||||
|
||||
vote_ids = [v.id for v in votes]
|
||||
bill_id_by_vote = {v.id: v.bill_id for v in votes}
|
||||
bills_with_votes = len({v.bill_id for v in votes})
|
||||
|
||||
# 4. Bulk fetch positions for followed members on those votes
|
||||
positions_result = await db.execute(
|
||||
select(MemberVotePosition).where(
|
||||
MemberVotePosition.vote_id.in_(vote_ids),
|
||||
MemberVotePosition.bioguide_id.in_(followed_member_ids),
|
||||
)
|
||||
)
|
||||
positions = positions_result.scalars().all()
|
||||
|
||||
# 5. Aggregate per member
|
||||
tally: dict[str, dict] = defaultdict(lambda: {"aligned": 0, "opposed": 0})
|
||||
|
||||
for pos in positions:
|
||||
if pos.position not in ("Yea", "Nay"):
|
||||
# Skip Not Voting / Present — not a real position signal
|
||||
continue
|
||||
bill_id = bill_id_by_vote.get(pos.vote_id)
|
||||
if not bill_id:
|
||||
continue
|
||||
stance = stance_map.get(bill_id)
|
||||
is_aligned = (
|
||||
(stance == "pocket_boost" and pos.position == "Yea") or
|
||||
(stance == "pocket_veto" and pos.position == "Nay")
|
||||
)
|
||||
if is_aligned:
|
||||
tally[pos.bioguide_id]["aligned"] += 1
|
||||
else:
|
||||
tally[pos.bioguide_id]["opposed"] += 1
|
||||
|
||||
if not tally:
|
||||
return {
|
||||
"members": [],
|
||||
"total_bills_with_stance": len(stance_map),
|
||||
"total_bills_with_votes": bills_with_votes,
|
||||
}
|
||||
|
||||
# 6. Load member details
|
||||
member_ids = list(tally.keys())
|
||||
members_result = await db.execute(
|
||||
select(Member).where(Member.bioguide_id.in_(member_ids))
|
||||
)
|
||||
members = members_result.scalars().all()
|
||||
member_map = {m.bioguide_id: m for m in members}
|
||||
|
||||
# 7. Build response
|
||||
result = []
|
||||
for bioguide_id, counts in tally.items():
|
||||
m = member_map.get(bioguide_id)
|
||||
aligned = counts["aligned"]
|
||||
opposed = counts["opposed"]
|
||||
total = aligned + opposed
|
||||
result.append({
|
||||
"bioguide_id": bioguide_id,
|
||||
"name": m.name if m else bioguide_id,
|
||||
"party": m.party if m else None,
|
||||
"state": m.state if m else None,
|
||||
"chamber": m.chamber if m else None,
|
||||
"photo_url": m.photo_url if m else None,
|
||||
"effectiveness_percentile": m.effectiveness_percentile if m else None,
|
||||
"aligned": aligned,
|
||||
"opposed": opposed,
|
||||
"total": total,
|
||||
"alignment_pct": round(aligned / total * 100, 1) if total > 0 else None,
|
||||
})
|
||||
|
||||
result.sort(key=lambda x: (x["alignment_pct"] is None, -(x["alignment_pct"] or 0)))
|
||||
|
||||
return {
|
||||
"members": result,
|
||||
"total_bills_with_stance": len(stance_map),
|
||||
"total_bills_with_votes": bills_with_votes,
|
||||
}
|
||||
58
backend/app/api/auth.py
Normal file
58
backend/app/api/auth.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.dependencies import get_current_user
|
||||
from app.core.security import create_access_token, hash_password, verify_password
|
||||
from app.database import get_db
|
||||
from app.models.user import User
|
||||
from app.schemas.schemas import TokenResponse, UserCreate, UserResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/register", response_model=TokenResponse, status_code=201)
|
||||
async def register(body: UserCreate, db: AsyncSession = Depends(get_db)):
|
||||
if len(body.password) < 8:
|
||||
raise HTTPException(status_code=400, detail="Password must be at least 8 characters")
|
||||
if "@" not in body.email:
|
||||
raise HTTPException(status_code=400, detail="Invalid email address")
|
||||
|
||||
# Check for duplicate email
|
||||
existing = await db.execute(select(User).where(User.email == body.email.lower()))
|
||||
if existing.scalar_one_or_none():
|
||||
raise HTTPException(status_code=409, detail="Email already registered")
|
||||
|
||||
# First registered user becomes admin
|
||||
count_result = await db.execute(select(func.count()).select_from(User))
|
||||
is_first_user = count_result.scalar() == 0
|
||||
|
||||
user = User(
|
||||
email=body.email.lower(),
|
||||
hashed_password=hash_password(body.password),
|
||||
is_admin=is_first_user,
|
||||
)
|
||||
db.add(user)
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
|
||||
return TokenResponse(access_token=create_access_token(user.id), user=user)
|
||||
|
||||
|
||||
@router.post("/login", response_model=TokenResponse)
|
||||
async def login(body: UserCreate, db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(select(User).where(User.email == body.email.lower()))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user or not verify_password(body.password, user.hashed_password):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect email or password",
|
||||
)
|
||||
|
||||
return TokenResponse(access_token=create_access_token(user.id), user=user)
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
|
||||
async def me(current_user: User = Depends(get_current_user)):
|
||||
return current_user
|
||||
277
backend/app/api/bills.py
Normal file
277
backend/app/api/bills.py
Normal file
@@ -0,0 +1,277 @@
|
||||
from typing import Literal, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import desc, func, or_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.database import get_db
|
||||
from app.models import Bill, BillAction, BillBrief, BillDocument, NewsArticle, TrendScore
|
||||
from app.schemas.schemas import (
|
||||
BillDetailSchema,
|
||||
BillSchema,
|
||||
BillActionSchema,
|
||||
BillVoteSchema,
|
||||
NewsArticleSchema,
|
||||
PaginatedResponse,
|
||||
TrendScoreSchema,
|
||||
)
|
||||
|
||||
_BILL_TYPE_LABELS: dict[str, str] = {
|
||||
"hr": "H.R.",
|
||||
"s": "S.",
|
||||
"hjres": "H.J.Res.",
|
||||
"sjres": "S.J.Res.",
|
||||
"hconres": "H.Con.Res.",
|
||||
"sconres": "S.Con.Res.",
|
||||
"hres": "H.Res.",
|
||||
"sres": "S.Res.",
|
||||
}
|
||||
|
||||
|
||||
class DraftLetterRequest(BaseModel):
|
||||
stance: Literal["yes", "no"]
|
||||
recipient: Literal["house", "senate"]
|
||||
tone: Literal["short", "polite", "firm"]
|
||||
selected_points: list[str]
|
||||
include_citations: bool = True
|
||||
zip_code: str | None = None # not stored, not logged
|
||||
rep_name: str | None = None # not stored, not logged
|
||||
|
||||
|
||||
class DraftLetterResponse(BaseModel):
|
||||
draft: str
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("", response_model=PaginatedResponse[BillSchema])
|
||||
async def list_bills(
|
||||
chamber: Optional[str] = Query(None),
|
||||
topic: Optional[str] = Query(None),
|
||||
sponsor_id: Optional[str] = Query(None),
|
||||
q: Optional[str] = Query(None),
|
||||
has_document: Optional[bool] = Query(None),
|
||||
page: int = Query(1, ge=1),
|
||||
per_page: int = Query(20, ge=1, le=100),
|
||||
sort: str = Query("latest_action_date"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
query = (
|
||||
select(Bill)
|
||||
.options(
|
||||
selectinload(Bill.sponsor),
|
||||
selectinload(Bill.briefs),
|
||||
selectinload(Bill.trend_scores),
|
||||
)
|
||||
)
|
||||
|
||||
if chamber:
|
||||
query = query.where(Bill.chamber == chamber)
|
||||
if sponsor_id:
|
||||
query = query.where(Bill.sponsor_id == sponsor_id)
|
||||
if topic:
|
||||
query = query.join(BillBrief, Bill.bill_id == BillBrief.bill_id).where(
|
||||
BillBrief.topic_tags.contains([topic])
|
||||
)
|
||||
if q:
|
||||
query = query.where(
|
||||
or_(
|
||||
Bill.bill_id.ilike(f"%{q}%"),
|
||||
Bill.title.ilike(f"%{q}%"),
|
||||
Bill.short_title.ilike(f"%{q}%"),
|
||||
)
|
||||
)
|
||||
if has_document is True:
|
||||
doc_subq = select(BillDocument.bill_id).where(BillDocument.bill_id == Bill.bill_id).exists()
|
||||
query = query.where(doc_subq)
|
||||
elif has_document is False:
|
||||
doc_subq = select(BillDocument.bill_id).where(BillDocument.bill_id == Bill.bill_id).exists()
|
||||
query = query.where(~doc_subq)
|
||||
|
||||
# Count total
|
||||
count_query = select(func.count()).select_from(query.subquery())
|
||||
total = await db.scalar(count_query) or 0
|
||||
|
||||
# Sort
|
||||
sort_col = getattr(Bill, sort, Bill.latest_action_date)
|
||||
query = query.order_by(desc(sort_col)).offset((page - 1) * per_page).limit(per_page)
|
||||
|
||||
result = await db.execute(query)
|
||||
bills = result.scalars().unique().all()
|
||||
|
||||
# Single batch query: which of these bills have at least one document?
|
||||
bill_ids = [b.bill_id for b in bills]
|
||||
doc_result = await db.execute(
|
||||
select(BillDocument.bill_id).where(BillDocument.bill_id.in_(bill_ids)).distinct()
|
||||
)
|
||||
bills_with_docs = {row[0] for row in doc_result}
|
||||
|
||||
# Attach latest brief, trend, and has_document to each bill
|
||||
items = []
|
||||
for bill in bills:
|
||||
bill_dict = BillSchema.model_validate(bill)
|
||||
if bill.briefs:
|
||||
bill_dict.latest_brief = bill.briefs[0]
|
||||
if bill.trend_scores:
|
||||
bill_dict.latest_trend = bill.trend_scores[0]
|
||||
bill_dict.has_document = bill.bill_id in bills_with_docs
|
||||
items.append(bill_dict)
|
||||
|
||||
return PaginatedResponse(
|
||||
items=items,
|
||||
total=total,
|
||||
page=page,
|
||||
per_page=per_page,
|
||||
pages=max(1, (total + per_page - 1) // per_page),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{bill_id}", response_model=BillDetailSchema)
|
||||
async def get_bill(bill_id: str, db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(
|
||||
select(Bill)
|
||||
.options(
|
||||
selectinload(Bill.sponsor),
|
||||
selectinload(Bill.actions),
|
||||
selectinload(Bill.briefs),
|
||||
selectinload(Bill.news_articles),
|
||||
selectinload(Bill.trend_scores),
|
||||
)
|
||||
.where(Bill.bill_id == bill_id)
|
||||
)
|
||||
bill = result.scalar_one_or_none()
|
||||
if not bill:
|
||||
from fastapi import HTTPException
|
||||
raise HTTPException(status_code=404, detail="Bill not found")
|
||||
|
||||
detail = BillDetailSchema.model_validate(bill)
|
||||
if bill.briefs:
|
||||
detail.latest_brief = bill.briefs[0]
|
||||
if bill.trend_scores:
|
||||
detail.latest_trend = bill.trend_scores[0]
|
||||
doc_exists = await db.scalar(
|
||||
select(func.count()).select_from(BillDocument).where(BillDocument.bill_id == bill_id)
|
||||
)
|
||||
detail.has_document = bool(doc_exists)
|
||||
|
||||
# Trigger a background news refresh if no articles are stored but trend
|
||||
# data shows there are gnews results out there waiting to be fetched.
|
||||
latest_trend = bill.trend_scores[0] if bill.trend_scores else None
|
||||
has_gnews = latest_trend and (latest_trend.gnews_count or 0) > 0
|
||||
if not bill.news_articles and has_gnews:
|
||||
try:
|
||||
from app.workers.news_fetcher import fetch_news_for_bill
|
||||
fetch_news_for_bill.delay(bill_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return detail
|
||||
|
||||
|
||||
@router.get("/{bill_id}/actions", response_model=list[BillActionSchema])
|
||||
async def get_bill_actions(bill_id: str, db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(
|
||||
select(BillAction)
|
||||
.where(BillAction.bill_id == bill_id)
|
||||
.order_by(desc(BillAction.action_date))
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
@router.get("/{bill_id}/news", response_model=list[NewsArticleSchema])
|
||||
async def get_bill_news(bill_id: str, db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(
|
||||
select(NewsArticle)
|
||||
.where(NewsArticle.bill_id == bill_id)
|
||||
.order_by(desc(NewsArticle.published_at))
|
||||
.limit(20)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
@router.get("/{bill_id}/trend", response_model=list[TrendScoreSchema])
|
||||
async def get_bill_trend(bill_id: str, days: int = Query(30, ge=7, le=365), db: AsyncSession = Depends(get_db)):
|
||||
from datetime import date, timedelta
|
||||
cutoff = date.today() - timedelta(days=days)
|
||||
result = await db.execute(
|
||||
select(TrendScore)
|
||||
.where(TrendScore.bill_id == bill_id, TrendScore.score_date >= cutoff)
|
||||
.order_by(TrendScore.score_date)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
@router.get("/{bill_id}/votes", response_model=list[BillVoteSchema])
|
||||
async def get_bill_votes_endpoint(bill_id: str, db: AsyncSession = Depends(get_db)):
|
||||
from app.models.vote import BillVote
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
result = await db.execute(
|
||||
select(BillVote)
|
||||
.where(BillVote.bill_id == bill_id)
|
||||
.options(selectinload(BillVote.positions))
|
||||
.order_by(desc(BillVote.vote_date))
|
||||
)
|
||||
votes = result.scalars().unique().all()
|
||||
|
||||
# Trigger background fetch if no votes are stored yet
|
||||
if not votes:
|
||||
bill = await db.get(Bill, bill_id)
|
||||
if bill:
|
||||
try:
|
||||
from app.workers.vote_fetcher import fetch_bill_votes
|
||||
fetch_bill_votes.delay(bill_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return votes
|
||||
|
||||
|
||||
@router.post("/{bill_id}/draft-letter", response_model=DraftLetterResponse)
|
||||
async def generate_letter(bill_id: str, body: DraftLetterRequest, db: AsyncSession = Depends(get_db)):
|
||||
from app.models.setting import AppSetting
|
||||
from app.services.llm_service import generate_draft_letter
|
||||
|
||||
bill = await db.get(Bill, bill_id)
|
||||
if not bill:
|
||||
raise HTTPException(status_code=404, detail="Bill not found")
|
||||
|
||||
if not body.selected_points:
|
||||
raise HTTPException(status_code=422, detail="At least one point must be selected")
|
||||
|
||||
prov_row = await db.get(AppSetting, "llm_provider")
|
||||
model_row = await db.get(AppSetting, "llm_model")
|
||||
llm_provider_override = prov_row.value if prov_row else None
|
||||
llm_model_override = model_row.value if model_row else None
|
||||
|
||||
type_label = _BILL_TYPE_LABELS.get((bill.bill_type or "").lower(), (bill.bill_type or "").upper())
|
||||
bill_label = f"{type_label} {bill.bill_number}"
|
||||
|
||||
try:
|
||||
draft = generate_draft_letter(
|
||||
bill_label=bill_label,
|
||||
bill_title=bill.short_title or bill.title or bill_label,
|
||||
stance=body.stance,
|
||||
recipient=body.recipient,
|
||||
tone=body.tone,
|
||||
selected_points=body.selected_points,
|
||||
include_citations=body.include_citations,
|
||||
zip_code=body.zip_code,
|
||||
rep_name=body.rep_name,
|
||||
llm_provider=llm_provider_override,
|
||||
llm_model=llm_model_override,
|
||||
)
|
||||
except Exception as exc:
|
||||
msg = str(exc)
|
||||
if "insufficient_quota" in msg or "quota" in msg.lower():
|
||||
detail = "LLM quota exceeded. Check your API key billing."
|
||||
elif "rate_limit" in msg.lower() or "429" in msg:
|
||||
detail = "LLM rate limit hit. Wait a moment and try again."
|
||||
elif "auth" in msg.lower() or "401" in msg or "403" in msg:
|
||||
detail = "LLM authentication failed. Check your API key."
|
||||
else:
|
||||
detail = f"LLM error: {msg[:200]}"
|
||||
raise HTTPException(status_code=502, detail=detail)
|
||||
return {"draft": draft}
|
||||
319
backend/app/api/collections.py
Normal file
319
backend/app/api/collections.py
Normal file
@@ -0,0 +1,319 @@
|
||||
"""
|
||||
Collections API — named, curated groups of bills with share links.
|
||||
"""
|
||||
import re
|
||||
import unicodedata
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.core.dependencies import get_current_user
|
||||
from app.database import get_db
|
||||
from app.models.bill import Bill, BillDocument
|
||||
from app.models.collection import Collection, CollectionBill
|
||||
from app.models.user import User
|
||||
from app.schemas.schemas import (
|
||||
BillSchema,
|
||||
CollectionCreate,
|
||||
CollectionDetailSchema,
|
||||
CollectionSchema,
|
||||
CollectionUpdate,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _slugify(text: str) -> str:
|
||||
text = unicodedata.normalize("NFKD", text).encode("ascii", "ignore").decode()
|
||||
text = re.sub(r"[^\w\s-]", "", text.lower())
|
||||
return re.sub(r"[-\s]+", "-", text).strip("-")
|
||||
|
||||
|
||||
async def _unique_slug(db: AsyncSession, user_id: int, name: str, exclude_id: int | None = None) -> str:
|
||||
base = _slugify(name) or "collection"
|
||||
slug = base
|
||||
counter = 2
|
||||
while True:
|
||||
q = select(Collection).where(Collection.user_id == user_id, Collection.slug == slug)
|
||||
if exclude_id is not None:
|
||||
q = q.where(Collection.id != exclude_id)
|
||||
existing = (await db.execute(q)).scalar_one_or_none()
|
||||
if not existing:
|
||||
return slug
|
||||
slug = f"{base}-{counter}"
|
||||
counter += 1
|
||||
|
||||
|
||||
def _to_schema(collection: Collection) -> CollectionSchema:
|
||||
return CollectionSchema(
|
||||
id=collection.id,
|
||||
name=collection.name,
|
||||
slug=collection.slug,
|
||||
is_public=collection.is_public,
|
||||
share_token=collection.share_token,
|
||||
bill_count=len(collection.collection_bills),
|
||||
created_at=collection.created_at,
|
||||
)
|
||||
|
||||
|
||||
async def _detail_schema(db: AsyncSession, collection: Collection) -> CollectionDetailSchema:
|
||||
"""Build CollectionDetailSchema with bills (including has_document)."""
|
||||
cb_list = collection.collection_bills
|
||||
bills = [cb.bill for cb in cb_list]
|
||||
|
||||
bill_ids = [b.bill_id for b in bills]
|
||||
if bill_ids:
|
||||
doc_result = await db.execute(
|
||||
select(BillDocument.bill_id).where(BillDocument.bill_id.in_(bill_ids)).distinct()
|
||||
)
|
||||
bills_with_docs = {row[0] for row in doc_result}
|
||||
else:
|
||||
bills_with_docs = set()
|
||||
|
||||
bill_schemas = []
|
||||
for bill in bills:
|
||||
bs = BillSchema.model_validate(bill)
|
||||
if bill.briefs:
|
||||
bs.latest_brief = bill.briefs[0]
|
||||
if bill.trend_scores:
|
||||
bs.latest_trend = bill.trend_scores[0]
|
||||
bs.has_document = bill.bill_id in bills_with_docs
|
||||
bill_schemas.append(bs)
|
||||
|
||||
return CollectionDetailSchema(
|
||||
id=collection.id,
|
||||
name=collection.name,
|
||||
slug=collection.slug,
|
||||
is_public=collection.is_public,
|
||||
share_token=collection.share_token,
|
||||
bill_count=len(cb_list),
|
||||
created_at=collection.created_at,
|
||||
bills=bill_schemas,
|
||||
)
|
||||
|
||||
|
||||
async def _load_collection(db: AsyncSession, collection_id: int) -> Collection:
|
||||
result = await db.execute(
|
||||
select(Collection)
|
||||
.options(
|
||||
selectinload(Collection.collection_bills).selectinload(CollectionBill.bill).selectinload(Bill.briefs),
|
||||
selectinload(Collection.collection_bills).selectinload(CollectionBill.bill).selectinload(Bill.trend_scores),
|
||||
selectinload(Collection.collection_bills).selectinload(CollectionBill.bill).selectinload(Bill.sponsor),
|
||||
)
|
||||
.where(Collection.id == collection_id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
# ── List ──────────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("", response_model=list[CollectionSchema])
|
||||
async def list_collections(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(Collection)
|
||||
.options(selectinload(Collection.collection_bills))
|
||||
.where(Collection.user_id == current_user.id)
|
||||
.order_by(Collection.created_at.desc())
|
||||
)
|
||||
collections = result.scalars().unique().all()
|
||||
return [_to_schema(c) for c in collections]
|
||||
|
||||
|
||||
# ── Create ────────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.post("", response_model=CollectionSchema, status_code=201)
|
||||
async def create_collection(
|
||||
body: CollectionCreate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
name = body.name.strip()
|
||||
if not 1 <= len(name) <= 100:
|
||||
raise HTTPException(status_code=422, detail="name must be 1–100 characters")
|
||||
|
||||
slug = await _unique_slug(db, current_user.id, name)
|
||||
collection = Collection(
|
||||
user_id=current_user.id,
|
||||
name=name,
|
||||
slug=slug,
|
||||
is_public=body.is_public,
|
||||
)
|
||||
db.add(collection)
|
||||
await db.flush()
|
||||
await db.execute(select(Collection).where(Collection.id == collection.id)) # ensure loaded
|
||||
await db.commit()
|
||||
await db.refresh(collection)
|
||||
|
||||
# Load collection_bills for bill_count
|
||||
result = await db.execute(
|
||||
select(Collection)
|
||||
.options(selectinload(Collection.collection_bills))
|
||||
.where(Collection.id == collection.id)
|
||||
)
|
||||
collection = result.scalar_one()
|
||||
return _to_schema(collection)
|
||||
|
||||
|
||||
# ── Share (public — no auth) ──────────────────────────────────────────────────
|
||||
|
||||
@router.get("/share/{share_token}", response_model=CollectionDetailSchema)
|
||||
async def get_collection_by_share_token(
|
||||
share_token: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(Collection)
|
||||
.options(
|
||||
selectinload(Collection.collection_bills).selectinload(CollectionBill.bill).selectinload(Bill.briefs),
|
||||
selectinload(Collection.collection_bills).selectinload(CollectionBill.bill).selectinload(Bill.trend_scores),
|
||||
selectinload(Collection.collection_bills).selectinload(CollectionBill.bill).selectinload(Bill.sponsor),
|
||||
)
|
||||
.where(Collection.share_token == share_token)
|
||||
)
|
||||
collection = result.scalar_one_or_none()
|
||||
if not collection:
|
||||
raise HTTPException(status_code=404, detail="Collection not found")
|
||||
return await _detail_schema(db, collection)
|
||||
|
||||
|
||||
# ── Get (owner) ───────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/{collection_id}", response_model=CollectionDetailSchema)
|
||||
async def get_collection(
|
||||
collection_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
collection = await _load_collection(db, collection_id)
|
||||
if not collection:
|
||||
raise HTTPException(status_code=404, detail="Collection not found")
|
||||
if collection.user_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
return await _detail_schema(db, collection)
|
||||
|
||||
|
||||
# ── Update ────────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.patch("/{collection_id}", response_model=CollectionSchema)
|
||||
async def update_collection(
|
||||
collection_id: int,
|
||||
body: CollectionUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(Collection)
|
||||
.options(selectinload(Collection.collection_bills))
|
||||
.where(Collection.id == collection_id)
|
||||
)
|
||||
collection = result.scalar_one_or_none()
|
||||
if not collection:
|
||||
raise HTTPException(status_code=404, detail="Collection not found")
|
||||
if collection.user_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
if body.name is not None:
|
||||
name = body.name.strip()
|
||||
if not 1 <= len(name) <= 100:
|
||||
raise HTTPException(status_code=422, detail="name must be 1–100 characters")
|
||||
collection.name = name
|
||||
collection.slug = await _unique_slug(db, current_user.id, name, exclude_id=collection_id)
|
||||
|
||||
if body.is_public is not None:
|
||||
collection.is_public = body.is_public
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(collection)
|
||||
|
||||
result = await db.execute(
|
||||
select(Collection)
|
||||
.options(selectinload(Collection.collection_bills))
|
||||
.where(Collection.id == collection_id)
|
||||
)
|
||||
collection = result.scalar_one()
|
||||
return _to_schema(collection)
|
||||
|
||||
|
||||
# ── Delete ────────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.delete("/{collection_id}", status_code=204)
|
||||
async def delete_collection(
|
||||
collection_id: int,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(select(Collection).where(Collection.id == collection_id))
|
||||
collection = result.scalar_one_or_none()
|
||||
if not collection:
|
||||
raise HTTPException(status_code=404, detail="Collection not found")
|
||||
if collection.user_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
await db.delete(collection)
|
||||
await db.commit()
|
||||
|
||||
|
||||
# ── Add bill ──────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.post("/{collection_id}/bills/{bill_id}", status_code=204)
|
||||
async def add_bill_to_collection(
|
||||
collection_id: int,
|
||||
bill_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(select(Collection).where(Collection.id == collection_id))
|
||||
collection = result.scalar_one_or_none()
|
||||
if not collection:
|
||||
raise HTTPException(status_code=404, detail="Collection not found")
|
||||
if collection.user_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
bill = await db.get(Bill, bill_id)
|
||||
if not bill:
|
||||
raise HTTPException(status_code=404, detail="Bill not found")
|
||||
|
||||
existing = await db.execute(
|
||||
select(CollectionBill).where(
|
||||
CollectionBill.collection_id == collection_id,
|
||||
CollectionBill.bill_id == bill_id,
|
||||
)
|
||||
)
|
||||
if existing.scalar_one_or_none():
|
||||
return # idempotent
|
||||
|
||||
db.add(CollectionBill(collection_id=collection_id, bill_id=bill_id))
|
||||
await db.commit()
|
||||
|
||||
|
||||
# ── Remove bill ───────────────────────────────────────────────────────────────
|
||||
|
||||
@router.delete("/{collection_id}/bills/{bill_id}", status_code=204)
|
||||
async def remove_bill_from_collection(
|
||||
collection_id: int,
|
||||
bill_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(select(Collection).where(Collection.id == collection_id))
|
||||
collection = result.scalar_one_or_none()
|
||||
if not collection:
|
||||
raise HTTPException(status_code=404, detail="Collection not found")
|
||||
if collection.user_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
cb_result = await db.execute(
|
||||
select(CollectionBill).where(
|
||||
CollectionBill.collection_id == collection_id,
|
||||
CollectionBill.bill_id == bill_id,
|
||||
)
|
||||
)
|
||||
cb = cb_result.scalar_one_or_none()
|
||||
if not cb:
|
||||
raise HTTPException(status_code=404, detail="Bill not in collection")
|
||||
await db.delete(cb)
|
||||
await db.commit()
|
||||
121
backend/app/api/dashboard.py
Normal file
121
backend/app/api/dashboard.py
Normal file
@@ -0,0 +1,121 @@
|
||||
from datetime import date, timedelta
|
||||
|
||||
from fastapi import Depends
|
||||
from fastapi import APIRouter
|
||||
from sqlalchemy import desc, or_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.core.dependencies import get_optional_user
|
||||
from app.database import get_db
|
||||
from app.models import Bill, BillBrief, Follow, TrendScore
|
||||
from app.models.user import User
|
||||
from app.schemas.schemas import BillSchema
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
async def _get_trending(db: AsyncSession) -> list[dict]:
|
||||
# Try progressively wider windows so stale scores still surface results
|
||||
for days_back in (1, 3, 7, 30):
|
||||
trending_result = await db.execute(
|
||||
select(Bill)
|
||||
.options(selectinload(Bill.sponsor), selectinload(Bill.briefs), selectinload(Bill.trend_scores))
|
||||
.join(TrendScore, Bill.bill_id == TrendScore.bill_id)
|
||||
.where(TrendScore.score_date >= date.today() - timedelta(days=days_back))
|
||||
.order_by(desc(TrendScore.composite_score))
|
||||
.limit(10)
|
||||
)
|
||||
trending_bills = trending_result.scalars().unique().all()
|
||||
if trending_bills:
|
||||
return [_serialize_bill(b) for b in trending_bills]
|
||||
return []
|
||||
|
||||
|
||||
def _serialize_bill(bill: Bill) -> dict:
|
||||
b = BillSchema.model_validate(bill)
|
||||
if bill.briefs:
|
||||
b.latest_brief = bill.briefs[0]
|
||||
if bill.trend_scores:
|
||||
b.latest_trend = bill.trend_scores[0]
|
||||
return b.model_dump()
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def get_dashboard(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User | None = Depends(get_optional_user),
|
||||
):
|
||||
trending = await _get_trending(db)
|
||||
|
||||
if current_user is None:
|
||||
return {"feed": [], "trending": trending, "follows": {"bills": 0, "members": 0, "topics": 0}}
|
||||
|
||||
# Load follows for the current user
|
||||
follows_result = await db.execute(
|
||||
select(Follow).where(Follow.user_id == current_user.id)
|
||||
)
|
||||
follows = follows_result.scalars().all()
|
||||
|
||||
followed_bill_ids = [f.follow_value for f in follows if f.follow_type == "bill"]
|
||||
followed_member_ids = [f.follow_value for f in follows if f.follow_type == "member"]
|
||||
followed_topics = [f.follow_value for f in follows if f.follow_type == "topic"]
|
||||
|
||||
feed_bills: list[Bill] = []
|
||||
seen_ids: set[str] = set()
|
||||
|
||||
# 1. Directly followed bills
|
||||
if followed_bill_ids:
|
||||
result = await db.execute(
|
||||
select(Bill)
|
||||
.options(selectinload(Bill.sponsor), selectinload(Bill.briefs), selectinload(Bill.trend_scores))
|
||||
.where(Bill.bill_id.in_(followed_bill_ids))
|
||||
.order_by(desc(Bill.latest_action_date))
|
||||
.limit(20)
|
||||
)
|
||||
for bill in result.scalars().all():
|
||||
if bill.bill_id not in seen_ids:
|
||||
feed_bills.append(bill)
|
||||
seen_ids.add(bill.bill_id)
|
||||
|
||||
# 2. Bills from followed members
|
||||
if followed_member_ids:
|
||||
result = await db.execute(
|
||||
select(Bill)
|
||||
.options(selectinload(Bill.sponsor), selectinload(Bill.briefs), selectinload(Bill.trend_scores))
|
||||
.where(Bill.sponsor_id.in_(followed_member_ids))
|
||||
.order_by(desc(Bill.latest_action_date))
|
||||
.limit(20)
|
||||
)
|
||||
for bill in result.scalars().all():
|
||||
if bill.bill_id not in seen_ids:
|
||||
feed_bills.append(bill)
|
||||
seen_ids.add(bill.bill_id)
|
||||
|
||||
# 3. Bills matching followed topics (single query with OR across all topics)
|
||||
if followed_topics:
|
||||
result = await db.execute(
|
||||
select(Bill)
|
||||
.options(selectinload(Bill.sponsor), selectinload(Bill.briefs), selectinload(Bill.trend_scores))
|
||||
.join(BillBrief, Bill.bill_id == BillBrief.bill_id)
|
||||
.where(or_(*[BillBrief.topic_tags.contains([t]) for t in followed_topics]))
|
||||
.order_by(desc(Bill.latest_action_date))
|
||||
.limit(20)
|
||||
)
|
||||
for bill in result.scalars().all():
|
||||
if bill.bill_id not in seen_ids:
|
||||
feed_bills.append(bill)
|
||||
seen_ids.add(bill.bill_id)
|
||||
|
||||
# Sort feed by latest action date
|
||||
feed_bills.sort(key=lambda b: b.latest_action_date or date.min, reverse=True)
|
||||
|
||||
return {
|
||||
"feed": [_serialize_bill(b) for b in feed_bills[:50]],
|
||||
"trending": trending,
|
||||
"follows": {
|
||||
"bills": len(followed_bill_ids),
|
||||
"members": len(followed_member_ids),
|
||||
"topics": len(followed_topics),
|
||||
},
|
||||
}
|
||||
94
backend/app/api/follows.py
Normal file
94
backend/app/api/follows.py
Normal file
@@ -0,0 +1,94 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.dependencies import get_current_user
|
||||
from app.database import get_db
|
||||
from app.models import Follow
|
||||
from app.models.user import User
|
||||
from app.schemas.schemas import FollowCreate, FollowModeUpdate, FollowSchema
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
VALID_FOLLOW_TYPES = {"bill", "member", "topic"}
|
||||
VALID_MODES = {"neutral", "pocket_veto", "pocket_boost"}
|
||||
|
||||
|
||||
@router.get("", response_model=list[FollowSchema])
|
||||
async def list_follows(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(Follow)
|
||||
.where(Follow.user_id == current_user.id)
|
||||
.order_by(Follow.created_at.desc())
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
@router.post("", response_model=FollowSchema, status_code=201)
|
||||
async def add_follow(
|
||||
body: FollowCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
if body.follow_type not in VALID_FOLLOW_TYPES:
|
||||
raise HTTPException(status_code=400, detail=f"follow_type must be one of {VALID_FOLLOW_TYPES}")
|
||||
follow = Follow(
|
||||
user_id=current_user.id,
|
||||
follow_type=body.follow_type,
|
||||
follow_value=body.follow_value,
|
||||
)
|
||||
db.add(follow)
|
||||
try:
|
||||
await db.commit()
|
||||
await db.refresh(follow)
|
||||
except IntegrityError:
|
||||
await db.rollback()
|
||||
# Already following — return existing
|
||||
result = await db.execute(
|
||||
select(Follow).where(
|
||||
Follow.user_id == current_user.id,
|
||||
Follow.follow_type == body.follow_type,
|
||||
Follow.follow_value == body.follow_value,
|
||||
)
|
||||
)
|
||||
return result.scalar_one()
|
||||
return follow
|
||||
|
||||
|
||||
@router.patch("/{follow_id}/mode", response_model=FollowSchema)
|
||||
async def update_follow_mode(
|
||||
follow_id: int,
|
||||
body: FollowModeUpdate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
if body.follow_mode not in VALID_MODES:
|
||||
raise HTTPException(status_code=400, detail=f"follow_mode must be one of {VALID_MODES}")
|
||||
follow = await db.get(Follow, follow_id)
|
||||
if not follow:
|
||||
raise HTTPException(status_code=404, detail="Follow not found")
|
||||
if follow.user_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="Not your follow")
|
||||
follow.follow_mode = body.follow_mode
|
||||
await db.commit()
|
||||
await db.refresh(follow)
|
||||
return follow
|
||||
|
||||
|
||||
@router.delete("/{follow_id}", status_code=204)
|
||||
async def remove_follow(
|
||||
follow_id: int,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
follow = await db.get(Follow, follow_id)
|
||||
if not follow:
|
||||
raise HTTPException(status_code=404, detail="Follow not found")
|
||||
if follow.user_id != current_user.id:
|
||||
raise HTTPException(status_code=403, detail="Not your follow")
|
||||
await db.delete(follow)
|
||||
await db.commit()
|
||||
43
backend/app/api/health.py
Normal file
43
backend/app/api/health.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import redis as redis_lib
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.database import get_db
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def health():
|
||||
return {"status": "ok", "timestamp": datetime.now(timezone.utc).isoformat()}
|
||||
|
||||
|
||||
@router.get("/detailed")
|
||||
async def health_detailed(db: AsyncSession = Depends(get_db)):
|
||||
# Check DB
|
||||
db_ok = False
|
||||
try:
|
||||
await db.execute(text("SELECT 1"))
|
||||
db_ok = True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Check Redis
|
||||
redis_ok = False
|
||||
try:
|
||||
r = redis_lib.from_url(settings.REDIS_URL)
|
||||
redis_ok = r.ping()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
status = "ok" if (db_ok and redis_ok) else "degraded"
|
||||
return {
|
||||
"status": status,
|
||||
"database": "ok" if db_ok else "error",
|
||||
"redis": "ok" if redis_ok else "error",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
313
backend/app/api/members.py
Normal file
313
backend/app/api/members.py
Normal file
@@ -0,0 +1,313 @@
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
|
||||
_FIPS_TO_STATE = {
|
||||
"01": "AL", "02": "AK", "04": "AZ", "05": "AR", "06": "CA",
|
||||
"08": "CO", "09": "CT", "10": "DE", "11": "DC", "12": "FL",
|
||||
"13": "GA", "15": "HI", "16": "ID", "17": "IL", "18": "IN",
|
||||
"19": "IA", "20": "KS", "21": "KY", "22": "LA", "23": "ME",
|
||||
"24": "MD", "25": "MA", "26": "MI", "27": "MN", "28": "MS",
|
||||
"29": "MO", "30": "MT", "31": "NE", "32": "NV", "33": "NH",
|
||||
"34": "NJ", "35": "NM", "36": "NY", "37": "NC", "38": "ND",
|
||||
"39": "OH", "40": "OK", "41": "OR", "42": "PA", "44": "RI",
|
||||
"45": "SC", "46": "SD", "47": "TN", "48": "TX", "49": "UT",
|
||||
"50": "VT", "51": "VA", "53": "WA", "54": "WV", "55": "WI",
|
||||
"56": "WY", "60": "AS", "66": "GU", "69": "MP", "72": "PR", "78": "VI",
|
||||
}
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import desc, func, or_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.database import get_db
|
||||
from app.models import Bill, Member, MemberTrendScore, MemberNewsArticle
|
||||
from app.schemas.schemas import (
|
||||
BillSchema, MemberSchema, MemberTrendScoreSchema,
|
||||
MemberNewsArticleSchema, PaginatedResponse,
|
||||
)
|
||||
from app.services import congress_api
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/by-zip/{zip_code}", response_model=list[MemberSchema])
|
||||
async def get_members_by_zip(zip_code: str, db: AsyncSession = Depends(get_db)):
|
||||
"""Return the House rep and senators for a ZIP code.
|
||||
Step 1: Nominatim (OpenStreetMap) — ZIP → lat/lng.
|
||||
Step 2: TIGERweb Legislative identify — lat/lng → congressional district.
|
||||
"""
|
||||
if not re.fullmatch(r"\d{5}", zip_code):
|
||||
raise HTTPException(status_code=400, detail="ZIP code must be 5 digits")
|
||||
|
||||
state_code: str | None = None
|
||||
district_num: str | None = None
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=20.0) as client:
|
||||
# Step 1: ZIP → lat/lng
|
||||
r1 = await client.get(
|
||||
"https://nominatim.openstreetmap.org/search",
|
||||
params={"postalcode": zip_code, "country": "US", "format": "json", "limit": "1"},
|
||||
headers={"User-Agent": "PocketVeto/1.0"},
|
||||
)
|
||||
places = r1.json() if r1.status_code == 200 else []
|
||||
if not places:
|
||||
logger.warning("Nominatim: no result for ZIP %s", zip_code)
|
||||
return []
|
||||
|
||||
lat = places[0]["lat"]
|
||||
lng = places[0]["lon"]
|
||||
|
||||
# Step 2: lat/lng → congressional district via TIGERweb identify (all layers)
|
||||
half = 0.5
|
||||
r2 = await client.get(
|
||||
"https://tigerweb.geo.census.gov/arcgis/rest/services/TIGERweb/Legislative/MapServer/identify",
|
||||
params={
|
||||
"f": "json",
|
||||
"geometry": f"{lng},{lat}",
|
||||
"geometryType": "esriGeometryPoint",
|
||||
"sr": "4326",
|
||||
"layers": "all",
|
||||
"tolerance": "2",
|
||||
"mapExtent": f"{float(lng)-half},{float(lat)-half},{float(lng)+half},{float(lat)+half}",
|
||||
"imageDisplay": "100,100,96",
|
||||
},
|
||||
)
|
||||
if r2.status_code != 200:
|
||||
logger.warning("TIGERweb returned %s for ZIP %s", r2.status_code, zip_code)
|
||||
return []
|
||||
|
||||
identify_results = r2.json().get("results", [])
|
||||
logger.info(
|
||||
"TIGERweb ZIP %s layers: %s",
|
||||
zip_code, [r.get("layerName") for r in identify_results],
|
||||
)
|
||||
|
||||
for item in identify_results:
|
||||
if "Congressional" not in (item.get("layerName") or ""):
|
||||
continue
|
||||
attrs = item.get("attributes", {})
|
||||
# GEOID = 2-char state FIPS + 2-char district (e.g. "1218" = FL-18)
|
||||
geoid = str(attrs.get("GEOID") or "").strip()
|
||||
if len(geoid) == 4:
|
||||
state_fips = geoid[:2]
|
||||
district_fips = geoid[2:]
|
||||
state_code = _FIPS_TO_STATE.get(state_fips)
|
||||
district_num = str(int(district_fips)) if district_fips.strip("0") else None
|
||||
if state_code:
|
||||
break
|
||||
|
||||
# Fallback: explicit field names
|
||||
cd_field = next((k for k in attrs if re.match(r"CD\d+FP$", k)), None)
|
||||
state_field = next((k for k in attrs if "STATEFP" in k.upper()), None)
|
||||
if cd_field and state_field:
|
||||
state_fips = str(attrs[state_field]).zfill(2)
|
||||
district_fips = str(attrs[cd_field])
|
||||
state_code = _FIPS_TO_STATE.get(state_fips)
|
||||
district_num = str(int(district_fips)) if district_fips.strip("0") else None
|
||||
if state_code:
|
||||
break
|
||||
|
||||
if not state_code:
|
||||
logger.warning(
|
||||
"ZIP %s: no CD found. Layers: %s",
|
||||
zip_code, [r.get("layerName") for r in identify_results],
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
logger.warning("ZIP lookup error for %s: %s", zip_code, exc)
|
||||
return []
|
||||
|
||||
if not state_code:
|
||||
return []
|
||||
|
||||
members: list[MemberSchema] = []
|
||||
seen: set[str] = set()
|
||||
|
||||
if district_num:
|
||||
result = await db.execute(
|
||||
select(Member).where(
|
||||
Member.state == state_code,
|
||||
Member.district == district_num,
|
||||
Member.chamber == "House of Representatives",
|
||||
)
|
||||
)
|
||||
member = result.scalar_one_or_none()
|
||||
if member:
|
||||
seen.add(member.bioguide_id)
|
||||
members.append(MemberSchema.model_validate(member))
|
||||
else:
|
||||
# At-large states (AK, DE, MT, ND, SD, VT, WY)
|
||||
result = await db.execute(
|
||||
select(Member).where(
|
||||
Member.state == state_code,
|
||||
Member.chamber == "House of Representatives",
|
||||
).limit(1)
|
||||
)
|
||||
member = result.scalar_one_or_none()
|
||||
if member:
|
||||
seen.add(member.bioguide_id)
|
||||
members.append(MemberSchema.model_validate(member))
|
||||
|
||||
result = await db.execute(
|
||||
select(Member).where(
|
||||
Member.state == state_code,
|
||||
Member.chamber == "Senate",
|
||||
)
|
||||
)
|
||||
for member in result.scalars().all():
|
||||
if member.bioguide_id not in seen:
|
||||
seen.add(member.bioguide_id)
|
||||
members.append(MemberSchema.model_validate(member))
|
||||
|
||||
return members
|
||||
|
||||
|
||||
@router.get("", response_model=PaginatedResponse[MemberSchema])
|
||||
async def list_members(
|
||||
chamber: Optional[str] = Query(None),
|
||||
party: Optional[str] = Query(None),
|
||||
state: Optional[str] = Query(None),
|
||||
q: Optional[str] = Query(None),
|
||||
page: int = Query(1, ge=1),
|
||||
per_page: int = Query(50, ge=1, le=250),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
query = select(Member)
|
||||
if chamber:
|
||||
query = query.where(Member.chamber == chamber)
|
||||
if party:
|
||||
query = query.where(Member.party == party)
|
||||
if state:
|
||||
query = query.where(Member.state == state)
|
||||
if q:
|
||||
# name is stored as "Last, First" — also match "First Last" order
|
||||
first_last = func.concat(
|
||||
func.split_part(Member.name, ", ", 2), " ",
|
||||
func.split_part(Member.name, ", ", 1),
|
||||
)
|
||||
query = query.where(or_(
|
||||
Member.name.ilike(f"%{q}%"),
|
||||
first_last.ilike(f"%{q}%"),
|
||||
))
|
||||
|
||||
total = await db.scalar(select(func.count()).select_from(query.subquery())) or 0
|
||||
query = query.order_by(Member.last_name, Member.first_name).offset((page - 1) * per_page).limit(per_page)
|
||||
|
||||
result = await db.execute(query)
|
||||
members = result.scalars().all()
|
||||
|
||||
return PaginatedResponse(
|
||||
items=members,
|
||||
total=total,
|
||||
page=page,
|
||||
per_page=per_page,
|
||||
pages=max(1, (total + per_page - 1) // per_page),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{bioguide_id}", response_model=MemberSchema)
|
||||
async def get_member(bioguide_id: str, db: AsyncSession = Depends(get_db)):
|
||||
member = await db.get(Member, bioguide_id)
|
||||
if not member:
|
||||
raise HTTPException(status_code=404, detail="Member not found")
|
||||
|
||||
# Kick off member interest on first view — single combined task avoids duplicate API calls
|
||||
if member.detail_fetched is None:
|
||||
try:
|
||||
from app.workers.member_interest import sync_member_interest
|
||||
sync_member_interest.delay(bioguide_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Lazy-enrich with detail data from Congress.gov on first view
|
||||
if member.detail_fetched is None:
|
||||
try:
|
||||
detail_raw = congress_api.get_member_detail(bioguide_id)
|
||||
enriched = congress_api.parse_member_detail_from_api(detail_raw)
|
||||
for field, value in enriched.items():
|
||||
if value is not None:
|
||||
setattr(member, field, value)
|
||||
member.detail_fetched = datetime.now(timezone.utc)
|
||||
await db.commit()
|
||||
await db.refresh(member)
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not enrich member detail for {bioguide_id}: {e}")
|
||||
|
||||
# Attach latest trend score
|
||||
result_schema = MemberSchema.model_validate(member)
|
||||
latest_trend = (
|
||||
await db.execute(
|
||||
select(MemberTrendScore)
|
||||
.where(MemberTrendScore.member_id == bioguide_id)
|
||||
.order_by(desc(MemberTrendScore.score_date))
|
||||
.limit(1)
|
||||
)
|
||||
)
|
||||
trend = latest_trend.scalar_one_or_none()
|
||||
if trend:
|
||||
result_schema.latest_trend = MemberTrendScoreSchema.model_validate(trend)
|
||||
return result_schema
|
||||
|
||||
|
||||
@router.get("/{bioguide_id}/trend", response_model=list[MemberTrendScoreSchema])
|
||||
async def get_member_trend(
|
||||
bioguide_id: str,
|
||||
days: int = Query(30, ge=7, le=365),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
from datetime import date, timedelta
|
||||
cutoff = date.today() - timedelta(days=days)
|
||||
result = await db.execute(
|
||||
select(MemberTrendScore)
|
||||
.where(MemberTrendScore.member_id == bioguide_id, MemberTrendScore.score_date >= cutoff)
|
||||
.order_by(MemberTrendScore.score_date)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
@router.get("/{bioguide_id}/news", response_model=list[MemberNewsArticleSchema])
|
||||
async def get_member_news(bioguide_id: str, db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(
|
||||
select(MemberNewsArticle)
|
||||
.where(MemberNewsArticle.member_id == bioguide_id)
|
||||
.order_by(desc(MemberNewsArticle.published_at))
|
||||
.limit(20)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
@router.get("/{bioguide_id}/bills", response_model=PaginatedResponse[BillSchema])
|
||||
async def get_member_bills(
|
||||
bioguide_id: str,
|
||||
page: int = Query(1, ge=1),
|
||||
per_page: int = Query(20, ge=1, le=100),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
query = select(Bill).options(selectinload(Bill.briefs), selectinload(Bill.sponsor)).where(Bill.sponsor_id == bioguide_id)
|
||||
total = await db.scalar(select(func.count()).select_from(query.subquery())) or 0
|
||||
query = query.order_by(desc(Bill.introduced_date)).offset((page - 1) * per_page).limit(per_page)
|
||||
|
||||
result = await db.execute(query)
|
||||
bills = result.scalars().all()
|
||||
|
||||
items = []
|
||||
for bill in bills:
|
||||
b = BillSchema.model_validate(bill)
|
||||
if bill.briefs:
|
||||
b.latest_brief = bill.briefs[0]
|
||||
items.append(b)
|
||||
|
||||
return PaginatedResponse(
|
||||
items=items,
|
||||
total=total,
|
||||
page=page,
|
||||
per_page=per_page,
|
||||
pages=max(1, (total + per_page - 1) // per_page),
|
||||
)
|
||||
89
backend/app/api/notes.py
Normal file
89
backend/app/api/notes.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""
|
||||
Bill Notes API — private per-user notes on individual bills.
|
||||
One note per (user, bill). PUT upserts, DELETE removes.
|
||||
"""
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.dependencies import get_current_user
|
||||
from app.database import get_db
|
||||
from app.models.bill import Bill
|
||||
from app.models.note import BillNote
|
||||
from app.models.user import User
|
||||
from app.schemas.schemas import BillNoteSchema, BillNoteUpsert
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/{bill_id}", response_model=BillNoteSchema)
|
||||
async def get_note(
|
||||
bill_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(BillNote).where(
|
||||
BillNote.user_id == current_user.id,
|
||||
BillNote.bill_id == bill_id,
|
||||
)
|
||||
)
|
||||
note = result.scalar_one_or_none()
|
||||
if not note:
|
||||
raise HTTPException(status_code=404, detail="No note for this bill")
|
||||
return note
|
||||
|
||||
|
||||
@router.put("/{bill_id}", response_model=BillNoteSchema)
|
||||
async def upsert_note(
|
||||
bill_id: str,
|
||||
body: BillNoteUpsert,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
bill = await db.get(Bill, bill_id)
|
||||
if not bill:
|
||||
raise HTTPException(status_code=404, detail="Bill not found")
|
||||
|
||||
result = await db.execute(
|
||||
select(BillNote).where(
|
||||
BillNote.user_id == current_user.id,
|
||||
BillNote.bill_id == bill_id,
|
||||
)
|
||||
)
|
||||
note = result.scalar_one_or_none()
|
||||
|
||||
if note:
|
||||
note.content = body.content
|
||||
note.pinned = body.pinned
|
||||
else:
|
||||
note = BillNote(
|
||||
user_id=current_user.id,
|
||||
bill_id=bill_id,
|
||||
content=body.content,
|
||||
pinned=body.pinned,
|
||||
)
|
||||
db.add(note)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(note)
|
||||
return note
|
||||
|
||||
|
||||
@router.delete("/{bill_id}", status_code=204)
|
||||
async def delete_note(
|
||||
bill_id: str,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(BillNote).where(
|
||||
BillNote.user_id == current_user.id,
|
||||
BillNote.bill_id == bill_id,
|
||||
)
|
||||
)
|
||||
note = result.scalar_one_or_none()
|
||||
if not note:
|
||||
raise HTTPException(status_code=404, detail="No note for this bill")
|
||||
await db.delete(note)
|
||||
await db.commit()
|
||||
465
backend/app/api/notifications.py
Normal file
465
backend/app/api/notifications.py
Normal file
@@ -0,0 +1,465 @@
|
||||
"""
|
||||
Notifications API — user notification settings and per-user RSS feed.
|
||||
"""
|
||||
import base64
|
||||
import secrets
|
||||
from xml.etree.ElementTree import Element, SubElement, tostring
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import HTMLResponse, Response
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings as app_settings
|
||||
from app.core.crypto import decrypt_secret, encrypt_secret
|
||||
from app.core.dependencies import get_current_user
|
||||
from app.database import get_db
|
||||
from app.models.notification import NotificationEvent
|
||||
from app.models.user import User
|
||||
from app.schemas.schemas import (
|
||||
FollowModeTestRequest,
|
||||
NotificationEventSchema,
|
||||
NotificationSettingsResponse,
|
||||
NotificationSettingsUpdate,
|
||||
NotificationTestResult,
|
||||
NtfyTestRequest,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
_EVENT_LABELS = {
|
||||
"new_document": "New Bill Text",
|
||||
"new_amendment": "Amendment Filed",
|
||||
"bill_updated": "Bill Updated",
|
||||
"weekly_digest": "Weekly Digest",
|
||||
}
|
||||
|
||||
|
||||
def _prefs_to_response(prefs: dict, rss_token: str | None) -> NotificationSettingsResponse:
|
||||
return NotificationSettingsResponse(
|
||||
ntfy_topic_url=prefs.get("ntfy_topic_url", ""),
|
||||
ntfy_auth_method=prefs.get("ntfy_auth_method", "none"),
|
||||
ntfy_token=prefs.get("ntfy_token", ""),
|
||||
ntfy_username=prefs.get("ntfy_username", ""),
|
||||
ntfy_password_set=bool(decrypt_secret(prefs.get("ntfy_password", ""))),
|
||||
ntfy_enabled=prefs.get("ntfy_enabled", False),
|
||||
rss_enabled=prefs.get("rss_enabled", False),
|
||||
rss_token=rss_token,
|
||||
email_enabled=prefs.get("email_enabled", False),
|
||||
email_address=prefs.get("email_address", ""),
|
||||
digest_enabled=prefs.get("digest_enabled", False),
|
||||
digest_frequency=prefs.get("digest_frequency", "daily"),
|
||||
quiet_hours_start=prefs.get("quiet_hours_start"),
|
||||
quiet_hours_end=prefs.get("quiet_hours_end"),
|
||||
timezone=prefs.get("timezone"),
|
||||
alert_filters=prefs.get("alert_filters"),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/settings", response_model=NotificationSettingsResponse)
|
||||
async def get_notification_settings(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
user = await db.get(User, current_user.id)
|
||||
# Auto-generate RSS token on first visit so the feed URL is always available
|
||||
if not user.rss_token:
|
||||
user.rss_token = secrets.token_urlsafe(32)
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
return _prefs_to_response(user.notification_prefs or {}, user.rss_token)
|
||||
|
||||
|
||||
@router.put("/settings", response_model=NotificationSettingsResponse)
|
||||
async def update_notification_settings(
|
||||
body: NotificationSettingsUpdate,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
user = await db.get(User, current_user.id)
|
||||
prefs = dict(user.notification_prefs or {})
|
||||
|
||||
if body.ntfy_topic_url is not None:
|
||||
prefs["ntfy_topic_url"] = body.ntfy_topic_url.strip()
|
||||
if body.ntfy_auth_method is not None:
|
||||
prefs["ntfy_auth_method"] = body.ntfy_auth_method
|
||||
if body.ntfy_token is not None:
|
||||
prefs["ntfy_token"] = body.ntfy_token.strip()
|
||||
if body.ntfy_username is not None:
|
||||
prefs["ntfy_username"] = body.ntfy_username.strip()
|
||||
if body.ntfy_password is not None:
|
||||
prefs["ntfy_password"] = encrypt_secret(body.ntfy_password.strip())
|
||||
if body.ntfy_enabled is not None:
|
||||
prefs["ntfy_enabled"] = body.ntfy_enabled
|
||||
if body.rss_enabled is not None:
|
||||
prefs["rss_enabled"] = body.rss_enabled
|
||||
if body.email_enabled is not None:
|
||||
prefs["email_enabled"] = body.email_enabled
|
||||
if body.email_address is not None:
|
||||
prefs["email_address"] = body.email_address.strip()
|
||||
if body.digest_enabled is not None:
|
||||
prefs["digest_enabled"] = body.digest_enabled
|
||||
if body.digest_frequency is not None:
|
||||
prefs["digest_frequency"] = body.digest_frequency
|
||||
if body.quiet_hours_start is not None:
|
||||
prefs["quiet_hours_start"] = body.quiet_hours_start
|
||||
if body.quiet_hours_end is not None:
|
||||
prefs["quiet_hours_end"] = body.quiet_hours_end
|
||||
if body.timezone is not None:
|
||||
prefs["timezone"] = body.timezone
|
||||
if body.alert_filters is not None:
|
||||
prefs["alert_filters"] = body.alert_filters
|
||||
# Allow clearing quiet hours by passing -1
|
||||
if body.quiet_hours_start == -1:
|
||||
prefs.pop("quiet_hours_start", None)
|
||||
prefs.pop("quiet_hours_end", None)
|
||||
prefs.pop("timezone", None)
|
||||
|
||||
user.notification_prefs = prefs
|
||||
|
||||
if not user.rss_token:
|
||||
user.rss_token = secrets.token_urlsafe(32)
|
||||
# Generate unsubscribe token the first time an email address is saved
|
||||
if prefs.get("email_address") and not user.email_unsubscribe_token:
|
||||
user.email_unsubscribe_token = secrets.token_urlsafe(32)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
return _prefs_to_response(user.notification_prefs or {}, user.rss_token)
|
||||
|
||||
|
||||
@router.post("/settings/rss-reset", response_model=NotificationSettingsResponse)
|
||||
async def reset_rss_token(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Regenerate the RSS token, invalidating the old feed URL."""
|
||||
user = await db.get(User, current_user.id)
|
||||
user.rss_token = secrets.token_urlsafe(32)
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
return _prefs_to_response(user.notification_prefs or {}, user.rss_token)
|
||||
|
||||
|
||||
@router.post("/test/ntfy", response_model=NotificationTestResult)
|
||||
async def test_ntfy(
|
||||
body: NtfyTestRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Send a test push notification to verify ntfy settings."""
|
||||
url = body.ntfy_topic_url.strip()
|
||||
if not url:
|
||||
return NotificationTestResult(status="error", detail="Topic URL is required")
|
||||
|
||||
base_url = (app_settings.PUBLIC_URL or app_settings.LOCAL_URL).rstrip("/")
|
||||
headers: dict[str, str] = {
|
||||
"Title": "PocketVeto: Test Notification",
|
||||
"Priority": "default",
|
||||
"Tags": "white_check_mark",
|
||||
"Click": f"{base_url}/notifications",
|
||||
}
|
||||
if body.ntfy_auth_method == "token" and body.ntfy_token.strip():
|
||||
headers["Authorization"] = f"Bearer {body.ntfy_token.strip()}"
|
||||
elif body.ntfy_auth_method == "basic" and body.ntfy_username.strip():
|
||||
creds = base64.b64encode(
|
||||
f"{body.ntfy_username.strip()}:{body.ntfy_password}".encode()
|
||||
).decode()
|
||||
headers["Authorization"] = f"Basic {creds}"
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
resp = await client.post(
|
||||
url,
|
||||
content=(
|
||||
"Your PocketVeto notification settings are working correctly. "
|
||||
"Real alerts will link directly to the relevant bill page."
|
||||
).encode("utf-8"),
|
||||
headers=headers,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return NotificationTestResult(status="ok", detail=f"Test notification sent (HTTP {resp.status_code})")
|
||||
except httpx.HTTPStatusError as e:
|
||||
return NotificationTestResult(status="error", detail=f"HTTP {e.response.status_code}: {e.response.text[:200]}")
|
||||
except httpx.RequestError as e:
|
||||
return NotificationTestResult(status="error", detail=f"Connection error: {e}")
|
||||
|
||||
|
||||
@router.post("/test/email", response_model=NotificationTestResult)
|
||||
async def test_email(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Send a test email to the user's configured email address."""
|
||||
import smtplib
|
||||
from email.mime.text import MIMEText
|
||||
|
||||
user = await db.get(User, current_user.id)
|
||||
prefs = user.notification_prefs or {}
|
||||
email_addr = prefs.get("email_address", "").strip()
|
||||
if not email_addr:
|
||||
return NotificationTestResult(status="error", detail="No email address saved. Save your address first.")
|
||||
|
||||
if not app_settings.SMTP_HOST:
|
||||
return NotificationTestResult(status="error", detail="SMTP not configured on this server. Set SMTP_HOST in .env")
|
||||
|
||||
try:
|
||||
from_addr = app_settings.SMTP_FROM or app_settings.SMTP_USER
|
||||
base_url = (app_settings.PUBLIC_URL or app_settings.LOCAL_URL).rstrip("/")
|
||||
body = (
|
||||
"This is a test email from PocketVeto.\n\n"
|
||||
"Your email notification settings are working correctly. "
|
||||
"Real alerts will include bill titles, summaries, and direct links.\n\n"
|
||||
f"Visit your notifications page: {base_url}/notifications"
|
||||
)
|
||||
msg = MIMEText(body, "plain", "utf-8")
|
||||
msg["Subject"] = "PocketVeto: Test Email Notification"
|
||||
msg["From"] = from_addr
|
||||
msg["To"] = email_addr
|
||||
|
||||
use_ssl = app_settings.SMTP_PORT == 465
|
||||
if use_ssl:
|
||||
ctx = smtplib.SMTP_SSL(app_settings.SMTP_HOST, app_settings.SMTP_PORT, timeout=10)
|
||||
else:
|
||||
ctx = smtplib.SMTP(app_settings.SMTP_HOST, app_settings.SMTP_PORT, timeout=10)
|
||||
with ctx as s:
|
||||
if not use_ssl and app_settings.SMTP_STARTTLS:
|
||||
s.starttls()
|
||||
if app_settings.SMTP_USER:
|
||||
s.login(app_settings.SMTP_USER, app_settings.SMTP_PASSWORD)
|
||||
s.sendmail(from_addr, [email_addr], msg.as_string())
|
||||
|
||||
return NotificationTestResult(status="ok", detail=f"Test email sent to {email_addr}")
|
||||
except smtplib.SMTPAuthenticationError:
|
||||
return NotificationTestResult(status="error", detail="SMTP authentication failed — check SMTP_USER and SMTP_PASSWORD in .env")
|
||||
except smtplib.SMTPConnectError:
|
||||
return NotificationTestResult(status="error", detail=f"Could not connect to {app_settings.SMTP_HOST}:{app_settings.SMTP_PORT}")
|
||||
except Exception as e:
|
||||
return NotificationTestResult(status="error", detail=str(e))
|
||||
|
||||
|
||||
@router.post("/test/rss", response_model=NotificationTestResult)
|
||||
async def test_rss(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Verify the user's RSS feed is reachable and return its event count."""
|
||||
user = await db.get(User, current_user.id)
|
||||
if not user.rss_token:
|
||||
return NotificationTestResult(status="error", detail="RSS token not generated — save settings first")
|
||||
|
||||
count_result = await db.execute(
|
||||
select(NotificationEvent).where(NotificationEvent.user_id == user.id)
|
||||
)
|
||||
event_count = len(count_result.scalars().all())
|
||||
|
||||
return NotificationTestResult(
|
||||
status="ok",
|
||||
detail=f"RSS feed is active with {event_count} event{'s' if event_count != 1 else ''}. Subscribe to the URL shown above.",
|
||||
event_count=event_count,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/history", response_model=list[NotificationEventSchema])
|
||||
async def get_notification_history(
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Return the 50 most recent notification events for the current user."""
|
||||
result = await db.execute(
|
||||
select(NotificationEvent)
|
||||
.where(NotificationEvent.user_id == current_user.id)
|
||||
.order_by(NotificationEvent.created_at.desc())
|
||||
.limit(50)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
@router.post("/test/follow-mode", response_model=NotificationTestResult)
|
||||
async def test_follow_mode(
|
||||
body: FollowModeTestRequest,
|
||||
current_user: User = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
"""Simulate dispatcher behaviour for a given follow mode + event type."""
|
||||
from sqlalchemy import select as sa_select
|
||||
from app.models.follow import Follow
|
||||
|
||||
VALID_MODES = {"pocket_veto", "pocket_boost"}
|
||||
VALID_EVENTS = {"new_document", "new_amendment", "bill_updated"}
|
||||
if body.mode not in VALID_MODES:
|
||||
return NotificationTestResult(status="error", detail=f"mode must be one of {VALID_MODES}")
|
||||
if body.event_type not in VALID_EVENTS:
|
||||
return NotificationTestResult(status="error", detail=f"event_type must be one of {VALID_EVENTS}")
|
||||
|
||||
result = await db.execute(
|
||||
sa_select(Follow).where(
|
||||
Follow.user_id == current_user.id,
|
||||
Follow.follow_type == "bill",
|
||||
).limit(1)
|
||||
)
|
||||
follow = result.scalar_one_or_none()
|
||||
if not follow:
|
||||
return NotificationTestResult(
|
||||
status="error",
|
||||
detail="No bill follows found — follow at least one bill first",
|
||||
)
|
||||
|
||||
# Pocket Veto suppression: brief events are silently dropped
|
||||
if body.mode == "pocket_veto" and body.event_type in ("new_document", "new_amendment"):
|
||||
return NotificationTestResult(
|
||||
status="ok",
|
||||
detail=(
|
||||
f"✓ Suppressed — Pocket Veto correctly blocked a '{body.event_type}' event. "
|
||||
"No ntfy was sent (this is the expected behaviour)."
|
||||
),
|
||||
)
|
||||
|
||||
# Everything else would send ntfy — check the user has it configured
|
||||
user = await db.get(User, current_user.id)
|
||||
prefs = user.notification_prefs or {}
|
||||
ntfy_url = prefs.get("ntfy_topic_url", "").strip()
|
||||
ntfy_enabled = prefs.get("ntfy_enabled", False)
|
||||
if not ntfy_enabled or not ntfy_url:
|
||||
return NotificationTestResult(
|
||||
status="error",
|
||||
detail="ntfy not configured or disabled — enable it in Notification Settings first.",
|
||||
)
|
||||
|
||||
bill_url = f"{(app_settings.PUBLIC_URL or app_settings.LOCAL_URL).rstrip('/')}/bills/{follow.follow_value}"
|
||||
event_titles = {
|
||||
"new_document": "New Bill Text",
|
||||
"new_amendment": "Amendment Filed",
|
||||
"bill_updated": "Bill Updated",
|
||||
}
|
||||
mode_label = body.mode.replace("_", " ").title()
|
||||
headers: dict[str, str] = {
|
||||
"Title": f"[{mode_label} Test] {event_titles[body.event_type]}: {follow.follow_value.upper()}",
|
||||
"Priority": "default",
|
||||
"Tags": "test_tube",
|
||||
"Click": bill_url,
|
||||
}
|
||||
if body.mode == "pocket_boost":
|
||||
headers["Actions"] = (
|
||||
f"view, View Bill, {bill_url}; "
|
||||
"view, Find Your Rep, https://www.house.gov/representatives/find-your-representative"
|
||||
)
|
||||
|
||||
auth_method = prefs.get("ntfy_auth_method", "none")
|
||||
ntfy_token = prefs.get("ntfy_token", "").strip()
|
||||
ntfy_username = prefs.get("ntfy_username", "").strip()
|
||||
ntfy_password = prefs.get("ntfy_password", "").strip()
|
||||
if auth_method == "token" and ntfy_token:
|
||||
headers["Authorization"] = f"Bearer {ntfy_token}"
|
||||
elif auth_method == "basic" and ntfy_username:
|
||||
creds = base64.b64encode(f"{ntfy_username}:{ntfy_password}".encode()).decode()
|
||||
headers["Authorization"] = f"Basic {creds}"
|
||||
|
||||
message_lines = [
|
||||
f"This is a test of {mode_label} mode for bill {follow.follow_value.upper()}.",
|
||||
f"Event type: {event_titles[body.event_type]}",
|
||||
]
|
||||
if body.mode == "pocket_boost":
|
||||
message_lines.append("Tap the action buttons below to view the bill or find your representative.")
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10) as client:
|
||||
resp = await client.post(
|
||||
ntfy_url,
|
||||
content="\n".join(message_lines).encode("utf-8"),
|
||||
headers=headers,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
detail = f"✓ ntfy sent (HTTP {resp.status_code})"
|
||||
if body.mode == "pocket_boost":
|
||||
detail += " — check your phone for 'View Bill' and 'Find Your Rep' action buttons"
|
||||
return NotificationTestResult(status="ok", detail=detail)
|
||||
except httpx.HTTPStatusError as e:
|
||||
return NotificationTestResult(status="error", detail=f"HTTP {e.response.status_code}: {e.response.text[:200]}")
|
||||
except httpx.RequestError as e:
|
||||
return NotificationTestResult(status="error", detail=f"Connection error: {e}")
|
||||
|
||||
|
||||
@router.get("/unsubscribe/{token}", response_class=HTMLResponse, include_in_schema=False)
|
||||
async def email_unsubscribe(token: str, db: AsyncSession = Depends(get_db)):
|
||||
"""One-click email unsubscribe — no login required."""
|
||||
result = await db.execute(
|
||||
select(User).where(User.email_unsubscribe_token == token)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user:
|
||||
return HTMLResponse(
|
||||
_unsubscribe_page("Invalid or expired link", success=False),
|
||||
status_code=404,
|
||||
)
|
||||
|
||||
prefs = dict(user.notification_prefs or {})
|
||||
prefs["email_enabled"] = False
|
||||
user.notification_prefs = prefs
|
||||
await db.commit()
|
||||
|
||||
return HTMLResponse(_unsubscribe_page("You've been unsubscribed from PocketVeto email notifications.", success=True))
|
||||
|
||||
|
||||
def _unsubscribe_page(message: str, success: bool) -> str:
|
||||
color = "#16a34a" if success else "#dc2626"
|
||||
icon = "✓" if success else "✗"
|
||||
return f"""<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head><meta charset="utf-8"><meta name="viewport" content="width=device-width,initial-scale=1">
|
||||
<title>PocketVeto — Unsubscribe</title>
|
||||
<style>
|
||||
body{{font-family:system-ui,sans-serif;background:#f9fafb;display:flex;align-items:center;justify-content:center;min-height:100vh;margin:0}}
|
||||
.card{{background:#fff;border:1px solid #e5e7eb;border-radius:12px;padding:2.5rem;max-width:420px;width:100%;text-align:center;box-shadow:0 1px 3px rgba(0,0,0,.08)}}
|
||||
.icon{{font-size:2.5rem;color:{color};margin-bottom:1rem}}
|
||||
h1{{font-size:1.1rem;font-weight:600;color:#111827;margin:0 0 .5rem}}
|
||||
p{{font-size:.9rem;color:#6b7280;margin:0 0 1.5rem;line-height:1.5}}
|
||||
a{{color:#2563eb;text-decoration:none;font-size:.875rem}}a:hover{{text-decoration:underline}}
|
||||
</style></head>
|
||||
<body><div class="card">
|
||||
<div class="icon">{icon}</div>
|
||||
<h1>Email Notifications</h1>
|
||||
<p>{message}</p>
|
||||
<a href="/">Return to PocketVeto</a>
|
||||
</div></body></html>"""
|
||||
|
||||
|
||||
@router.get("/feed/{rss_token}.xml", include_in_schema=False)
|
||||
async def rss_feed(rss_token: str, db: AsyncSession = Depends(get_db)):
|
||||
"""Public tokenized RSS feed — no auth required."""
|
||||
result = await db.execute(select(User).where(User.rss_token == rss_token))
|
||||
user = result.scalar_one_or_none()
|
||||
if not user:
|
||||
raise HTTPException(status_code=404, detail="Feed not found")
|
||||
|
||||
events_result = await db.execute(
|
||||
select(NotificationEvent)
|
||||
.where(NotificationEvent.user_id == user.id)
|
||||
.order_by(NotificationEvent.created_at.desc())
|
||||
.limit(50)
|
||||
)
|
||||
events = events_result.scalars().all()
|
||||
return Response(content=_build_rss(events), media_type="application/rss+xml")
|
||||
|
||||
|
||||
def _build_rss(events: list) -> bytes:
|
||||
rss = Element("rss", version="2.0")
|
||||
channel = SubElement(rss, "channel")
|
||||
SubElement(channel, "title").text = "PocketVeto — Bill Alerts"
|
||||
SubElement(channel, "description").text = "Updates on your followed bills"
|
||||
SubElement(channel, "language").text = "en-us"
|
||||
|
||||
for event in events:
|
||||
payload = event.payload or {}
|
||||
item = SubElement(channel, "item")
|
||||
label = _EVENT_LABELS.get(event.event_type, "Update")
|
||||
bill_label = payload.get("bill_label", event.bill_id.upper())
|
||||
SubElement(item, "title").text = f"{label}: {bill_label} — {payload.get('bill_title', '')}"
|
||||
SubElement(item, "description").text = payload.get("brief_summary", "")
|
||||
if payload.get("bill_url"):
|
||||
SubElement(item, "link").text = payload["bill_url"]
|
||||
SubElement(item, "pubDate").text = event.created_at.strftime("%a, %d %b %Y %H:%M:%S +0000")
|
||||
SubElement(item, "guid").text = str(event.id)
|
||||
|
||||
return tostring(rss, encoding="unicode").encode("utf-8")
|
||||
60
backend/app/api/search.py
Normal file
60
backend/app/api/search.py
Normal file
@@ -0,0 +1,60 @@
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy import func, or_, select, text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db
|
||||
from app.models import Bill, Member
|
||||
from app.schemas.schemas import BillSchema, MemberSchema
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def search(
|
||||
q: str = Query(..., min_length=2, max_length=500),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
# Bill ID direct match
|
||||
id_results = await db.execute(
|
||||
select(Bill).where(Bill.bill_id.ilike(f"%{q}%")).limit(20)
|
||||
)
|
||||
id_bills = id_results.scalars().all()
|
||||
|
||||
# Full-text search on title/content via tsvector
|
||||
fts_results = await db.execute(
|
||||
select(Bill)
|
||||
.where(text("search_vector @@ plainto_tsquery('english', :q)"))
|
||||
.order_by(text("ts_rank(search_vector, plainto_tsquery('english', :q)) DESC"))
|
||||
.limit(20)
|
||||
.params(q=q)
|
||||
)
|
||||
fts_bills = fts_results.scalars().all()
|
||||
|
||||
# Merge, dedup, preserve order (ID matches first)
|
||||
seen = set()
|
||||
bills = []
|
||||
for b in id_bills + fts_bills:
|
||||
if b.bill_id not in seen:
|
||||
seen.add(b.bill_id)
|
||||
bills.append(b)
|
||||
|
||||
# Fuzzy member search — matches "Last, First" and "First Last"
|
||||
first_last = func.concat(
|
||||
func.split_part(Member.name, ", ", 2), " ",
|
||||
func.split_part(Member.name, ", ", 1),
|
||||
)
|
||||
member_results = await db.execute(
|
||||
select(Member)
|
||||
.where(or_(
|
||||
Member.name.ilike(f"%{q}%"),
|
||||
first_last.ilike(f"%{q}%"),
|
||||
))
|
||||
.order_by(Member.last_name)
|
||||
.limit(10)
|
||||
)
|
||||
members = member_results.scalars().all()
|
||||
|
||||
return {
|
||||
"bills": [BillSchema.model_validate(b) for b in bills],
|
||||
"members": [MemberSchema.model_validate(m) for m in members],
|
||||
}
|
||||
225
backend/app/api/settings.py
Normal file
225
backend/app/api/settings.py
Normal file
@@ -0,0 +1,225 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.core.dependencies import get_current_admin, get_current_user
|
||||
from app.database import get_db
|
||||
from app.models import AppSetting
|
||||
from app.models.user import User
|
||||
from app.schemas.schemas import SettingUpdate, SettingsResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("", response_model=SettingsResponse)
|
||||
async def get_settings(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_user),
|
||||
):
|
||||
"""Return current effective settings (env + DB overrides)."""
|
||||
# DB overrides take precedence over env vars
|
||||
overrides: dict[str, str] = {}
|
||||
result = await db.execute(select(AppSetting))
|
||||
for row in result.scalars().all():
|
||||
overrides[row.key] = row.value
|
||||
|
||||
return SettingsResponse(
|
||||
llm_provider=overrides.get("llm_provider", settings.LLM_PROVIDER),
|
||||
llm_model=overrides.get("llm_model", _current_model(overrides.get("llm_provider", settings.LLM_PROVIDER))),
|
||||
congress_poll_interval_minutes=int(overrides.get("congress_poll_interval_minutes", settings.CONGRESS_POLL_INTERVAL_MINUTES)),
|
||||
newsapi_enabled=bool(settings.NEWSAPI_KEY),
|
||||
pytrends_enabled=settings.PYTRENDS_ENABLED,
|
||||
api_keys_configured={
|
||||
"openai": bool(settings.OPENAI_API_KEY),
|
||||
"anthropic": bool(settings.ANTHROPIC_API_KEY),
|
||||
"gemini": bool(settings.GEMINI_API_KEY),
|
||||
"ollama": True, # no API key required
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.put("")
|
||||
async def update_setting(
|
||||
body: SettingUpdate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_admin),
|
||||
):
|
||||
"""Update a runtime setting."""
|
||||
ALLOWED_KEYS = {"llm_provider", "llm_model", "congress_poll_interval_minutes"}
|
||||
if body.key not in ALLOWED_KEYS:
|
||||
from fastapi import HTTPException
|
||||
raise HTTPException(status_code=400, detail=f"Allowed setting keys: {ALLOWED_KEYS}")
|
||||
|
||||
existing = await db.get(AppSetting, body.key)
|
||||
if existing:
|
||||
existing.value = body.value
|
||||
else:
|
||||
db.add(AppSetting(key=body.key, value=body.value))
|
||||
await db.commit()
|
||||
return {"key": body.key, "value": body.value}
|
||||
|
||||
|
||||
@router.post("/test-llm")
|
||||
async def test_llm_connection(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_admin),
|
||||
):
|
||||
"""Ping the configured LLM provider with a minimal request."""
|
||||
import asyncio
|
||||
prov_row = await db.get(AppSetting, "llm_provider")
|
||||
model_row = await db.get(AppSetting, "llm_model")
|
||||
provider_name = prov_row.value if prov_row else settings.LLM_PROVIDER
|
||||
model_name = model_row.value if model_row else None
|
||||
try:
|
||||
return await asyncio.to_thread(_ping_provider, provider_name, model_name)
|
||||
except Exception as exc:
|
||||
return {"status": "error", "detail": str(exc)}
|
||||
|
||||
|
||||
_PING = "Reply with exactly three words: Connection test successful."
|
||||
|
||||
|
||||
def _ping_provider(provider_name: str, model_name: str | None) -> dict:
|
||||
if provider_name == "openai":
|
||||
from openai import OpenAI
|
||||
model = model_name or settings.OPENAI_MODEL
|
||||
client = OpenAI(api_key=settings.OPENAI_API_KEY)
|
||||
resp = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[{"role": "user", "content": _PING}],
|
||||
max_tokens=20,
|
||||
)
|
||||
reply = resp.choices[0].message.content.strip()
|
||||
return {"status": "ok", "provider": "openai", "model": model, "reply": reply}
|
||||
|
||||
if provider_name == "anthropic":
|
||||
import anthropic
|
||||
model = model_name or settings.ANTHROPIC_MODEL
|
||||
client = anthropic.Anthropic(api_key=settings.ANTHROPIC_API_KEY)
|
||||
resp = client.messages.create(
|
||||
model=model,
|
||||
max_tokens=20,
|
||||
messages=[{"role": "user", "content": _PING}],
|
||||
)
|
||||
reply = resp.content[0].text.strip()
|
||||
return {"status": "ok", "provider": "anthropic", "model": model, "reply": reply}
|
||||
|
||||
if provider_name == "gemini":
|
||||
import google.generativeai as genai
|
||||
model = model_name or settings.GEMINI_MODEL
|
||||
genai.configure(api_key=settings.GEMINI_API_KEY)
|
||||
resp = genai.GenerativeModel(model_name=model).generate_content(_PING)
|
||||
reply = resp.text.strip()
|
||||
return {"status": "ok", "provider": "gemini", "model": model, "reply": reply}
|
||||
|
||||
if provider_name == "ollama":
|
||||
import requests as req
|
||||
model = model_name or settings.OLLAMA_MODEL
|
||||
resp = req.post(
|
||||
f"{settings.OLLAMA_BASE_URL}/api/generate",
|
||||
json={"model": model, "prompt": _PING, "stream": False},
|
||||
timeout=30,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
reply = resp.json().get("response", "").strip()
|
||||
return {"status": "ok", "provider": "ollama", "model": model, "reply": reply}
|
||||
|
||||
raise ValueError(f"Unknown provider: {provider_name}")
|
||||
|
||||
|
||||
@router.get("/llm-models")
|
||||
async def list_llm_models(
|
||||
provider: str,
|
||||
current_user: User = Depends(get_current_admin),
|
||||
):
|
||||
"""Fetch available models directly from the provider's API."""
|
||||
import asyncio
|
||||
handlers = {
|
||||
"openai": _list_openai_models,
|
||||
"anthropic": _list_anthropic_models,
|
||||
"gemini": _list_gemini_models,
|
||||
"ollama": _list_ollama_models,
|
||||
}
|
||||
fn = handlers.get(provider)
|
||||
if not fn:
|
||||
return {"models": [], "error": f"Unknown provider: {provider}"}
|
||||
try:
|
||||
return await asyncio.to_thread(fn)
|
||||
except Exception as exc:
|
||||
return {"models": [], "error": str(exc)}
|
||||
|
||||
|
||||
def _list_openai_models() -> dict:
|
||||
from openai import OpenAI
|
||||
if not settings.OPENAI_API_KEY:
|
||||
return {"models": [], "error": "OPENAI_API_KEY not configured"}
|
||||
client = OpenAI(api_key=settings.OPENAI_API_KEY)
|
||||
all_models = client.models.list().data
|
||||
CHAT_PREFIXES = ("gpt-", "o1", "o3", "o4", "chatgpt-")
|
||||
EXCLUDE = ("realtime", "audio", "tts", "whisper", "embedding", "dall-e", "instruct")
|
||||
filtered = sorted(
|
||||
[m.id for m in all_models
|
||||
if any(m.id.startswith(p) for p in CHAT_PREFIXES)
|
||||
and not any(x in m.id for x in EXCLUDE)],
|
||||
reverse=True,
|
||||
)
|
||||
return {"models": [{"id": m, "name": m} for m in filtered]}
|
||||
|
||||
|
||||
def _list_anthropic_models() -> dict:
|
||||
import requests as req
|
||||
if not settings.ANTHROPIC_API_KEY:
|
||||
return {"models": [], "error": "ANTHROPIC_API_KEY not configured"}
|
||||
resp = req.get(
|
||||
"https://api.anthropic.com/v1/models",
|
||||
headers={
|
||||
"x-api-key": settings.ANTHROPIC_API_KEY,
|
||||
"anthropic-version": "2023-06-01",
|
||||
},
|
||||
timeout=10,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
return {
|
||||
"models": [
|
||||
{"id": m["id"], "name": m.get("display_name", m["id"])}
|
||||
for m in data.get("data", [])
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def _list_gemini_models() -> dict:
|
||||
import google.generativeai as genai
|
||||
if not settings.GEMINI_API_KEY:
|
||||
return {"models": [], "error": "GEMINI_API_KEY not configured"}
|
||||
genai.configure(api_key=settings.GEMINI_API_KEY)
|
||||
models = [
|
||||
{"id": m.name.replace("models/", ""), "name": m.display_name}
|
||||
for m in genai.list_models()
|
||||
if "generateContent" in m.supported_generation_methods
|
||||
]
|
||||
return {"models": sorted(models, key=lambda x: x["id"])}
|
||||
|
||||
|
||||
def _list_ollama_models() -> dict:
|
||||
import requests as req
|
||||
try:
|
||||
resp = req.get(f"{settings.OLLAMA_BASE_URL}/api/tags", timeout=5)
|
||||
resp.raise_for_status()
|
||||
tags = resp.json().get("models", [])
|
||||
return {"models": [{"id": m["name"], "name": m["name"]} for m in tags]}
|
||||
except Exception as exc:
|
||||
return {"models": [], "error": f"Ollama unreachable: {exc}"}
|
||||
|
||||
|
||||
def _current_model(provider: str) -> str:
|
||||
if provider == "openai":
|
||||
return settings.OPENAI_MODEL
|
||||
elif provider == "anthropic":
|
||||
return settings.ANTHROPIC_MODEL
|
||||
elif provider == "gemini":
|
||||
return settings.GEMINI_MODEL
|
||||
elif provider == "ollama":
|
||||
return settings.OLLAMA_MODEL
|
||||
return "unknown"
|
||||
113
backend/app/api/share.py
Normal file
113
backend/app/api/share.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""
|
||||
Public share router — no authentication required.
|
||||
Serves shareable read-only views for briefs and collections.
|
||||
"""
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.database import get_db
|
||||
from app.models.bill import Bill, BillDocument
|
||||
from app.models.brief import BillBrief
|
||||
from app.models.collection import Collection, CollectionBill
|
||||
from app.schemas.schemas import (
|
||||
BillSchema,
|
||||
BriefSchema,
|
||||
BriefShareResponse,
|
||||
CollectionDetailSchema,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ── Brief share ───────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/brief/{token}", response_model=BriefShareResponse)
|
||||
async def get_shared_brief(
|
||||
token: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(BillBrief)
|
||||
.options(
|
||||
selectinload(BillBrief.bill).selectinload(Bill.sponsor),
|
||||
selectinload(BillBrief.bill).selectinload(Bill.briefs),
|
||||
selectinload(BillBrief.bill).selectinload(Bill.trend_scores),
|
||||
)
|
||||
.where(BillBrief.share_token == token)
|
||||
)
|
||||
brief = result.scalar_one_or_none()
|
||||
if not brief:
|
||||
raise HTTPException(status_code=404, detail="Brief not found")
|
||||
|
||||
bill = brief.bill
|
||||
bill_schema = BillSchema.model_validate(bill)
|
||||
if bill.briefs:
|
||||
bill_schema.latest_brief = bill.briefs[0]
|
||||
if bill.trend_scores:
|
||||
bill_schema.latest_trend = bill.trend_scores[0]
|
||||
|
||||
doc_result = await db.execute(
|
||||
select(BillDocument.bill_id).where(BillDocument.bill_id == bill.bill_id).limit(1)
|
||||
)
|
||||
bill_schema.has_document = doc_result.scalar_one_or_none() is not None
|
||||
|
||||
return BriefShareResponse(
|
||||
brief=BriefSchema.model_validate(brief),
|
||||
bill=bill_schema,
|
||||
)
|
||||
|
||||
|
||||
# ── Collection share ──────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/collection/{token}", response_model=CollectionDetailSchema)
|
||||
async def get_shared_collection(
|
||||
token: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
result = await db.execute(
|
||||
select(Collection)
|
||||
.options(
|
||||
selectinload(Collection.collection_bills).selectinload(CollectionBill.bill).selectinload(Bill.briefs),
|
||||
selectinload(Collection.collection_bills).selectinload(CollectionBill.bill).selectinload(Bill.trend_scores),
|
||||
selectinload(Collection.collection_bills).selectinload(CollectionBill.bill).selectinload(Bill.sponsor),
|
||||
)
|
||||
.where(Collection.share_token == token)
|
||||
)
|
||||
collection = result.scalar_one_or_none()
|
||||
if not collection:
|
||||
raise HTTPException(status_code=404, detail="Collection not found")
|
||||
|
||||
cb_list = collection.collection_bills
|
||||
bills = [cb.bill for cb in cb_list]
|
||||
bill_ids = [b.bill_id for b in bills]
|
||||
|
||||
if bill_ids:
|
||||
doc_result = await db.execute(
|
||||
select(BillDocument.bill_id).where(BillDocument.bill_id.in_(bill_ids)).distinct()
|
||||
)
|
||||
bills_with_docs = {row[0] for row in doc_result}
|
||||
else:
|
||||
bills_with_docs = set()
|
||||
|
||||
bill_schemas = []
|
||||
for bill in bills:
|
||||
bs = BillSchema.model_validate(bill)
|
||||
if bill.briefs:
|
||||
bs.latest_brief = bill.briefs[0]
|
||||
if bill.trend_scores:
|
||||
bs.latest_trend = bill.trend_scores[0]
|
||||
bs.has_document = bill.bill_id in bills_with_docs
|
||||
bill_schemas.append(bs)
|
||||
|
||||
return CollectionDetailSchema(
|
||||
id=collection.id,
|
||||
name=collection.name,
|
||||
slug=collection.slug,
|
||||
is_public=collection.is_public,
|
||||
share_token=collection.share_token,
|
||||
bill_count=len(cb_list),
|
||||
created_at=collection.created_at,
|
||||
bills=bill_schemas,
|
||||
)
|
||||
Reference in New Issue
Block a user