112 lines
3.6 KiB
Python
112 lines
3.6 KiB
Python
"""
|
||
Trend scorer — calculates the daily zeitgeist score for bills.
|
||
Runs nightly via Celery Beat.
|
||
"""
|
||
import logging
|
||
from datetime import date, timedelta
|
||
|
||
from sqlalchemy import and_
|
||
|
||
from app.database import get_sync_db
|
||
from app.models import Bill, BillBrief, TrendScore
|
||
from app.services import news_service, trends_service
|
||
from app.workers.celery_app import celery_app
|
||
|
||
# Per-module logger, keyed by this module's import path.
logger = logging.getLogger(name=__name__)
|
||
|
||
|
||
def calculate_composite_score(newsapi_count: int, gnews_count: int, gtrends_score: float) -> float:
    """
    Weighted composite score (0–100):

      NewsAPI article count → 0–40 pts (saturates at 20 articles)
      Google News RSS count → 0–30 pts (saturates at 50 articles)
      Google Trends score   → 0–30 pts (0–100 input)

    Each component's fraction is clamped to [0, 1]. Out-of-range inputs
    therefore cannot push the total outside 0–100: a negative count, or a
    Trends value above 100 or below 0. In-range inputs score exactly as before.

    Args:
        newsapi_count: Number of NewsAPI articles found.
        gnews_count: Number of Google News RSS articles found.
        gtrends_score: Google Trends interest value, nominally 0–100.

    Returns:
        The composite score, rounded to 2 decimal places.
    """

    def _fraction(value: float, saturation: float) -> float:
        # Share of the saturation point reached, bounded to [0, 1].
        return min(max(value / saturation, 0.0), 1.0)

    newsapi_pts = _fraction(newsapi_count, 20) * 40
    gnews_pts = _fraction(gnews_count, 50) * 30
    gtrends_pts = _fraction(gtrends_score, 100) * 30
    return round(newsapi_pts + gnews_pts + gtrends_pts, 2)
|
||
|
||
|
||
def _latest_topic_tags(db, bill_id):
    """Return the topic tags of the bill's most recent brief, or [] if there are none."""
    latest_brief = (
        db.query(BillBrief)
        .filter_by(bill_id=bill_id)
        .order_by(BillBrief.created_at.desc())
        .first()
    )
    # A brief may exist with an empty/NULL tag column; always hand back a list.
    if latest_brief is None or not latest_brief.topic_tags:
        return []
    return latest_brief.topic_tags


def _build_trend_score(db, bill, score_date):
    """Fetch news and Google Trends signals for one bill; return an unsaved TrendScore."""
    topic_tags = _latest_topic_tags(db, bill.bill_id)

    # News coverage over the last 30 days.
    query = news_service.build_news_query(
        bill_title=bill.title,
        short_title=bill.short_title,
        sponsor_name=None,
        bill_type=bill.bill_type,
        bill_number=bill.bill_number,
    )
    newsapi_count = len(news_service.fetch_newsapi_articles(query, days=30))
    gnews_count = news_service.fetch_gnews_count(query, days=30)

    # Search interest.
    keywords = trends_service.keywords_for_bill(
        title=bill.title or "",
        short_title=bill.short_title or "",
        topic_tags=topic_tags,
    )
    gtrends_score = trends_service.get_trends_score(keywords)

    return TrendScore(
        bill_id=bill.bill_id,
        score_date=score_date,
        newsapi_count=newsapi_count,
        gnews_count=gnews_count,
        gtrends_score=gtrends_score,
        composite_score=calculate_composite_score(newsapi_count, gnews_count, gtrends_score),
    )


@celery_app.task(bind=True, name="app.workers.trend_scorer.calculate_all_trend_scores")
def calculate_all_trend_scores(self):
    """Nightly task: calculate trend scores for bills active in the last 90 days.

    Bills with a ``latest_action_date`` on or after the cutoff are scored.
    Bills that already have a TrendScore row for today are skipped. New rows
    are committed in batches of 20, with a final commit at the end.

    Returns:
        ``{"scored": <number of new TrendScore rows added>}``.

    Raises:
        Re-raises any exception after rolling back the uncommitted batch.
        Batches already committed are kept.
    """
    db = get_sync_db()
    try:
        # Read the date once so the 90-day window and score_date agree even if
        # the run crosses midnight.
        today = date.today()
        cutoff = today - timedelta(days=90)

        active_bills = db.query(Bill).filter(Bill.latest_action_date >= cutoff).all()

        # Fetch today's already-scored bill ids in a single query rather than
        # issuing one existence check per bill.
        already_scored = {
            bill_id
            for (bill_id,) in db.query(TrendScore.bill_id).filter(TrendScore.score_date == today)
        }

        scored = 0
        for bill in active_bills:
            if bill.bill_id in already_scored:
                continue

            db.add(_build_trend_score(db, bill, today))
            scored += 1

            # Periodic commits: a failure late in the run keeps earlier batches.
            if scored % 20 == 0:
                db.commit()

        db.commit()
        logger.info("Scored %d bills", scored)
        return {"scored": scored}

    except Exception as exc:
        db.rollback()
        # logger.exception preserves the traceback, which logger.error dropped.
        logger.exception("Trend scoring failed: %s", exc)
        raise
    finally:
        db.close()
|