Initial commit
This commit is contained in:
0
backend/app/api/__init__.py
Normal file
0
backend/app/api/__init__.py
Normal file
39
backend/app/api/admin.py
Normal file
39
backend/app/api/admin.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/trigger-poll")
|
||||
async def trigger_poll():
|
||||
"""Manually trigger a Congress.gov poll without waiting for the Beat schedule."""
|
||||
from app.workers.congress_poller import poll_congress_bills
|
||||
task = poll_congress_bills.delay()
|
||||
return {"task_id": task.id, "status": "queued"}
|
||||
|
||||
|
||||
@router.post("/trigger-member-sync")
|
||||
async def trigger_member_sync():
|
||||
"""Manually trigger a member sync."""
|
||||
from app.workers.congress_poller import sync_members
|
||||
task = sync_members.delay()
|
||||
return {"task_id": task.id, "status": "queued"}
|
||||
|
||||
|
||||
@router.post("/trigger-trend-scores")
|
||||
async def trigger_trend_scores():
|
||||
"""Manually trigger trend score calculation."""
|
||||
from app.workers.trend_scorer import calculate_all_trend_scores
|
||||
task = calculate_all_trend_scores.delay()
|
||||
return {"task_id": task.id, "status": "queued"}
|
||||
|
||||
|
||||
@router.get("/task-status/{task_id}")
|
||||
async def get_task_status(task_id: str):
|
||||
"""Check the status of an async task."""
|
||||
from app.workers.celery_app import celery_app
|
||||
result = celery_app.AsyncResult(task_id)
|
||||
return {
|
||||
"task_id": task_id,
|
||||
"status": result.status,
|
||||
"result": result.result if result.ready() else None,
|
||||
}
|
||||
145
backend/app/api/bills.py
Normal file
145
backend/app/api/bills.py
Normal file
@@ -0,0 +1,145 @@
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy import desc, func, or_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.database import get_db
|
||||
from app.models import Bill, BillAction, BillBrief, NewsArticle, TrendScore
|
||||
from app.schemas.schemas import (
|
||||
BillDetailSchema,
|
||||
BillSchema,
|
||||
BillActionSchema,
|
||||
NewsArticleSchema,
|
||||
PaginatedResponse,
|
||||
TrendScoreSchema,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("", response_model=PaginatedResponse[BillSchema])
|
||||
async def list_bills(
|
||||
chamber: Optional[str] = Query(None),
|
||||
topic: Optional[str] = Query(None),
|
||||
sponsor_id: Optional[str] = Query(None),
|
||||
q: Optional[str] = Query(None),
|
||||
page: int = Query(1, ge=1),
|
||||
per_page: int = Query(20, ge=1, le=100),
|
||||
sort: str = Query("latest_action_date"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
query = (
|
||||
select(Bill)
|
||||
.options(
|
||||
selectinload(Bill.sponsor),
|
||||
selectinload(Bill.briefs),
|
||||
selectinload(Bill.trend_scores),
|
||||
)
|
||||
)
|
||||
|
||||
if chamber:
|
||||
query = query.where(Bill.chamber == chamber)
|
||||
if sponsor_id:
|
||||
query = query.where(Bill.sponsor_id == sponsor_id)
|
||||
if topic:
|
||||
query = query.join(BillBrief, Bill.bill_id == BillBrief.bill_id).where(
|
||||
BillBrief.topic_tags.contains([topic])
|
||||
)
|
||||
if q:
|
||||
query = query.where(
|
||||
or_(
|
||||
Bill.bill_id.ilike(f"%{q}%"),
|
||||
Bill.title.ilike(f"%{q}%"),
|
||||
Bill.short_title.ilike(f"%{q}%"),
|
||||
)
|
||||
)
|
||||
|
||||
# Count total
|
||||
count_query = select(func.count()).select_from(query.subquery())
|
||||
total = await db.scalar(count_query) or 0
|
||||
|
||||
# Sort
|
||||
sort_col = getattr(Bill, sort, Bill.latest_action_date)
|
||||
query = query.order_by(desc(sort_col)).offset((page - 1) * per_page).limit(per_page)
|
||||
|
||||
result = await db.execute(query)
|
||||
bills = result.scalars().unique().all()
|
||||
|
||||
# Attach latest brief and trend to each bill
|
||||
items = []
|
||||
for bill in bills:
|
||||
bill_dict = BillSchema.model_validate(bill)
|
||||
if bill.briefs:
|
||||
bill_dict.latest_brief = bill.briefs[0]
|
||||
if bill.trend_scores:
|
||||
bill_dict.latest_trend = bill.trend_scores[0]
|
||||
items.append(bill_dict)
|
||||
|
||||
return PaginatedResponse(
|
||||
items=items,
|
||||
total=total,
|
||||
page=page,
|
||||
per_page=per_page,
|
||||
pages=max(1, (total + per_page - 1) // per_page),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{bill_id}", response_model=BillDetailSchema)
|
||||
async def get_bill(bill_id: str, db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(
|
||||
select(Bill)
|
||||
.options(
|
||||
selectinload(Bill.sponsor),
|
||||
selectinload(Bill.actions),
|
||||
selectinload(Bill.briefs),
|
||||
selectinload(Bill.news_articles),
|
||||
selectinload(Bill.trend_scores),
|
||||
)
|
||||
.where(Bill.bill_id == bill_id)
|
||||
)
|
||||
bill = result.scalar_one_or_none()
|
||||
if not bill:
|
||||
from fastapi import HTTPException
|
||||
raise HTTPException(status_code=404, detail="Bill not found")
|
||||
|
||||
detail = BillDetailSchema.model_validate(bill)
|
||||
if bill.briefs:
|
||||
detail.latest_brief = bill.briefs[0]
|
||||
if bill.trend_scores:
|
||||
detail.latest_trend = bill.trend_scores[0]
|
||||
return detail
|
||||
|
||||
|
||||
@router.get("/{bill_id}/actions", response_model=list[BillActionSchema])
|
||||
async def get_bill_actions(bill_id: str, db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(
|
||||
select(BillAction)
|
||||
.where(BillAction.bill_id == bill_id)
|
||||
.order_by(desc(BillAction.action_date))
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
@router.get("/{bill_id}/news", response_model=list[NewsArticleSchema])
|
||||
async def get_bill_news(bill_id: str, db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(
|
||||
select(NewsArticle)
|
||||
.where(NewsArticle.bill_id == bill_id)
|
||||
.order_by(desc(NewsArticle.published_at))
|
||||
.limit(20)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
@router.get("/{bill_id}/trend", response_model=list[TrendScoreSchema])
|
||||
async def get_bill_trend(bill_id: str, days: int = Query(30, ge=7, le=365), db: AsyncSession = Depends(get_db)):
|
||||
from datetime import date, timedelta
|
||||
cutoff = date.today() - timedelta(days=days)
|
||||
result = await db.execute(
|
||||
select(TrendScore)
|
||||
.where(TrendScore.bill_id == bill_id, TrendScore.score_date >= cutoff)
|
||||
.order_by(TrendScore.score_date)
|
||||
)
|
||||
return result.scalars().all()
|
||||
102
backend/app/api/dashboard.py
Normal file
102
backend/app/api/dashboard.py
Normal file
@@ -0,0 +1,102 @@
|
||||
from datetime import date, timedelta
|
||||
|
||||
from fastapi import Depends
|
||||
from fastapi import APIRouter
|
||||
from sqlalchemy import desc, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.database import get_db
|
||||
from app.models import Bill, BillBrief, Follow, TrendScore
|
||||
from app.schemas.schemas import BillSchema
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def get_dashboard(db: AsyncSession = Depends(get_db)):
|
||||
# Load all follows
|
||||
follows_result = await db.execute(select(Follow))
|
||||
follows = follows_result.scalars().all()
|
||||
|
||||
followed_bill_ids = [f.follow_value for f in follows if f.follow_type == "bill"]
|
||||
followed_member_ids = [f.follow_value for f in follows if f.follow_type == "member"]
|
||||
followed_topics = [f.follow_value for f in follows if f.follow_type == "topic"]
|
||||
|
||||
feed_bills: list[Bill] = []
|
||||
seen_ids: set[str] = set()
|
||||
|
||||
# 1. Directly followed bills
|
||||
if followed_bill_ids:
|
||||
result = await db.execute(
|
||||
select(Bill)
|
||||
.options(selectinload(Bill.sponsor), selectinload(Bill.briefs), selectinload(Bill.trend_scores))
|
||||
.where(Bill.bill_id.in_(followed_bill_ids))
|
||||
.order_by(desc(Bill.latest_action_date))
|
||||
.limit(20)
|
||||
)
|
||||
for bill in result.scalars().all():
|
||||
if bill.bill_id not in seen_ids:
|
||||
feed_bills.append(bill)
|
||||
seen_ids.add(bill.bill_id)
|
||||
|
||||
# 2. Bills from followed members
|
||||
if followed_member_ids:
|
||||
result = await db.execute(
|
||||
select(Bill)
|
||||
.options(selectinload(Bill.sponsor), selectinload(Bill.briefs), selectinload(Bill.trend_scores))
|
||||
.where(Bill.sponsor_id.in_(followed_member_ids))
|
||||
.order_by(desc(Bill.latest_action_date))
|
||||
.limit(20)
|
||||
)
|
||||
for bill in result.scalars().all():
|
||||
if bill.bill_id not in seen_ids:
|
||||
feed_bills.append(bill)
|
||||
seen_ids.add(bill.bill_id)
|
||||
|
||||
# 3. Bills matching followed topics
|
||||
for topic in followed_topics:
|
||||
result = await db.execute(
|
||||
select(Bill)
|
||||
.options(selectinload(Bill.sponsor), selectinload(Bill.briefs), selectinload(Bill.trend_scores))
|
||||
.join(BillBrief, Bill.bill_id == BillBrief.bill_id)
|
||||
.where(BillBrief.topic_tags.contains([topic]))
|
||||
.order_by(desc(Bill.latest_action_date))
|
||||
.limit(10)
|
||||
)
|
||||
for bill in result.scalars().all():
|
||||
if bill.bill_id not in seen_ids:
|
||||
feed_bills.append(bill)
|
||||
seen_ids.add(bill.bill_id)
|
||||
|
||||
# Sort feed by latest action date
|
||||
feed_bills.sort(key=lambda b: b.latest_action_date or date.min, reverse=True)
|
||||
|
||||
# 4. Trending bills (top 10 by composite score today)
|
||||
trending_result = await db.execute(
|
||||
select(Bill)
|
||||
.options(selectinload(Bill.sponsor), selectinload(Bill.briefs), selectinload(Bill.trend_scores))
|
||||
.join(TrendScore, Bill.bill_id == TrendScore.bill_id)
|
||||
.where(TrendScore.score_date >= date.today() - timedelta(days=1))
|
||||
.order_by(desc(TrendScore.composite_score))
|
||||
.limit(10)
|
||||
)
|
||||
trending_bills = trending_result.scalars().unique().all()
|
||||
|
||||
def serialize_bill(bill: Bill) -> dict:
|
||||
b = BillSchema.model_validate(bill)
|
||||
if bill.briefs:
|
||||
b.latest_brief = bill.briefs[0]
|
||||
if bill.trend_scores:
|
||||
b.latest_trend = bill.trend_scores[0]
|
||||
return b.model_dump()
|
||||
|
||||
return {
|
||||
"feed": [serialize_bill(b) for b in feed_bills[:50]],
|
||||
"trending": [serialize_bill(b) for b in trending_bills],
|
||||
"follows": {
|
||||
"bills": len(followed_bill_ids),
|
||||
"members": len(followed_member_ids),
|
||||
"topics": len(followed_topics),
|
||||
},
|
||||
}
|
||||
49
backend/app/api/follows.py
Normal file
49
backend/app/api/follows.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db
|
||||
from app.models import Follow
|
||||
from app.schemas.schemas import FollowCreate, FollowSchema
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
VALID_FOLLOW_TYPES = {"bill", "member", "topic"}
|
||||
|
||||
|
||||
@router.get("", response_model=list[FollowSchema])
|
||||
async def list_follows(db: AsyncSession = Depends(get_db)):
|
||||
result = await db.execute(select(Follow).order_by(Follow.created_at.desc()))
|
||||
return result.scalars().all()
|
||||
|
||||
|
||||
@router.post("", response_model=FollowSchema, status_code=201)
|
||||
async def add_follow(body: FollowCreate, db: AsyncSession = Depends(get_db)):
|
||||
if body.follow_type not in VALID_FOLLOW_TYPES:
|
||||
raise HTTPException(status_code=400, detail=f"follow_type must be one of {VALID_FOLLOW_TYPES}")
|
||||
follow = Follow(follow_type=body.follow_type, follow_value=body.follow_value)
|
||||
db.add(follow)
|
||||
try:
|
||||
await db.commit()
|
||||
await db.refresh(follow)
|
||||
except IntegrityError:
|
||||
await db.rollback()
|
||||
# Already following — return existing
|
||||
result = await db.execute(
|
||||
select(Follow).where(
|
||||
Follow.follow_type == body.follow_type,
|
||||
Follow.follow_value == body.follow_value,
|
||||
)
|
||||
)
|
||||
return result.scalar_one()
|
||||
return follow
|
||||
|
||||
|
||||
@router.delete("/{follow_id}", status_code=204)
|
||||
async def remove_follow(follow_id: int, db: AsyncSession = Depends(get_db)):
|
||||
follow = await db.get(Follow, follow_id)
|
||||
if not follow:
|
||||
raise HTTPException(status_code=404, detail="Follow not found")
|
||||
await db.delete(follow)
|
||||
await db.commit()
|
||||
43
backend/app/api/health.py
Normal file
43
backend/app/api/health.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import redis as redis_lib
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.database import get_db
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def health():
|
||||
return {"status": "ok", "timestamp": datetime.now(timezone.utc).isoformat()}
|
||||
|
||||
|
||||
@router.get("/detailed")
|
||||
async def health_detailed(db: AsyncSession = Depends(get_db)):
|
||||
# Check DB
|
||||
db_ok = False
|
||||
try:
|
||||
await db.execute(text("SELECT 1"))
|
||||
db_ok = True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Check Redis
|
||||
redis_ok = False
|
||||
try:
|
||||
r = redis_lib.from_url(settings.REDIS_URL)
|
||||
redis_ok = r.ping()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
status = "ok" if (db_ok and redis_ok) else "degraded"
|
||||
return {
|
||||
"status": status,
|
||||
"database": "ok" if db_ok else "error",
|
||||
"redis": "ok" if redis_ok else "error",
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
85
backend/app/api/members.py
Normal file
85
backend/app/api/members.py
Normal file
@@ -0,0 +1,85 @@
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy import desc, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.database import get_db
|
||||
from app.models import Bill, Member
|
||||
from app.schemas.schemas import BillSchema, MemberSchema, PaginatedResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("", response_model=PaginatedResponse[MemberSchema])
|
||||
async def list_members(
|
||||
chamber: Optional[str] = Query(None),
|
||||
party: Optional[str] = Query(None),
|
||||
state: Optional[str] = Query(None),
|
||||
q: Optional[str] = Query(None),
|
||||
page: int = Query(1, ge=1),
|
||||
per_page: int = Query(50, ge=1, le=250),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
query = select(Member)
|
||||
if chamber:
|
||||
query = query.where(Member.chamber == chamber)
|
||||
if party:
|
||||
query = query.where(Member.party == party)
|
||||
if state:
|
||||
query = query.where(Member.state == state)
|
||||
if q:
|
||||
query = query.where(Member.name.ilike(f"%{q}%"))
|
||||
|
||||
total = await db.scalar(select(func.count()).select_from(query.subquery())) or 0
|
||||
query = query.order_by(Member.last_name, Member.first_name).offset((page - 1) * per_page).limit(per_page)
|
||||
|
||||
result = await db.execute(query)
|
||||
members = result.scalars().all()
|
||||
|
||||
return PaginatedResponse(
|
||||
items=members,
|
||||
total=total,
|
||||
page=page,
|
||||
per_page=per_page,
|
||||
pages=max(1, (total + per_page - 1) // per_page),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{bioguide_id}", response_model=MemberSchema)
|
||||
async def get_member(bioguide_id: str, db: AsyncSession = Depends(get_db)):
|
||||
member = await db.get(Member, bioguide_id)
|
||||
if not member:
|
||||
raise HTTPException(status_code=404, detail="Member not found")
|
||||
return member
|
||||
|
||||
|
||||
@router.get("/{bioguide_id}/bills", response_model=PaginatedResponse[BillSchema])
|
||||
async def get_member_bills(
|
||||
bioguide_id: str,
|
||||
page: int = Query(1, ge=1),
|
||||
per_page: int = Query(20, ge=1, le=100),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
query = select(Bill).options(selectinload(Bill.briefs)).where(Bill.sponsor_id == bioguide_id)
|
||||
total = await db.scalar(select(func.count()).select_from(query.subquery())) or 0
|
||||
query = query.order_by(desc(Bill.introduced_date)).offset((page - 1) * per_page).limit(per_page)
|
||||
|
||||
result = await db.execute(query)
|
||||
bills = result.scalars().all()
|
||||
|
||||
items = []
|
||||
for bill in bills:
|
||||
b = BillSchema.model_validate(bill)
|
||||
if bill.briefs:
|
||||
b.latest_brief = bill.briefs[0]
|
||||
items.append(b)
|
||||
|
||||
return PaginatedResponse(
|
||||
items=items,
|
||||
total=total,
|
||||
page=page,
|
||||
per_page=per_page,
|
||||
pages=max(1, (total + per_page - 1) // per_page),
|
||||
)
|
||||
53
backend/app/api/search.py
Normal file
53
backend/app/api/search.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy import select, text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.database import get_db
|
||||
from app.models import Bill, Member
|
||||
from app.schemas.schemas import BillSchema, MemberSchema
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def search(
|
||||
q: str = Query(..., min_length=2),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
):
|
||||
# Bill ID direct match
|
||||
id_results = await db.execute(
|
||||
select(Bill).where(Bill.bill_id.ilike(f"%{q}%")).limit(20)
|
||||
)
|
||||
id_bills = id_results.scalars().all()
|
||||
|
||||
# Full-text search on title/content via tsvector
|
||||
fts_results = await db.execute(
|
||||
select(Bill)
|
||||
.where(text("search_vector @@ plainto_tsquery('english', :q)"))
|
||||
.order_by(text("ts_rank(search_vector, plainto_tsquery('english', :q)) DESC"))
|
||||
.limit(20)
|
||||
.params(q=q)
|
||||
)
|
||||
fts_bills = fts_results.scalars().all()
|
||||
|
||||
# Merge, dedup, preserve order (ID matches first)
|
||||
seen = set()
|
||||
bills = []
|
||||
for b in id_bills + fts_bills:
|
||||
if b.bill_id not in seen:
|
||||
seen.add(b.bill_id)
|
||||
bills.append(b)
|
||||
|
||||
# Fuzzy member search
|
||||
member_results = await db.execute(
|
||||
select(Member)
|
||||
.where(Member.name.ilike(f"%{q}%"))
|
||||
.order_by(Member.last_name)
|
||||
.limit(10)
|
||||
)
|
||||
members = member_results.scalars().all()
|
||||
|
||||
return {
|
||||
"bills": [BillSchema.model_validate(b) for b in bills],
|
||||
"members": [MemberSchema.model_validate(m) for m in members],
|
||||
}
|
||||
86
backend/app/api/settings.py
Normal file
86
backend/app/api/settings.py
Normal file
@@ -0,0 +1,86 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import settings
|
||||
from app.database import get_db
|
||||
from app.models import AppSetting
|
||||
from app.schemas.schemas import SettingUpdate, SettingsResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("", response_model=SettingsResponse)
|
||||
async def get_settings(db: AsyncSession = Depends(get_db)):
|
||||
"""Return current effective settings (env + DB overrides)."""
|
||||
# DB overrides take precedence over env vars
|
||||
overrides: dict[str, str] = {}
|
||||
result = await db.execute(select(AppSetting))
|
||||
for row in result.scalars().all():
|
||||
overrides[row.key] = row.value
|
||||
|
||||
return SettingsResponse(
|
||||
llm_provider=overrides.get("llm_provider", settings.LLM_PROVIDER),
|
||||
llm_model=overrides.get("llm_model", _current_model(overrides.get("llm_provider", settings.LLM_PROVIDER))),
|
||||
congress_poll_interval_minutes=int(overrides.get("congress_poll_interval_minutes", settings.CONGRESS_POLL_INTERVAL_MINUTES)),
|
||||
newsapi_enabled=bool(settings.NEWSAPI_KEY),
|
||||
pytrends_enabled=settings.PYTRENDS_ENABLED,
|
||||
)
|
||||
|
||||
|
||||
@router.put("")
|
||||
async def update_setting(body: SettingUpdate, db: AsyncSession = Depends(get_db)):
|
||||
"""Update a runtime setting."""
|
||||
ALLOWED_KEYS = {"llm_provider", "llm_model", "congress_poll_interval_minutes"}
|
||||
if body.key not in ALLOWED_KEYS:
|
||||
from fastapi import HTTPException
|
||||
raise HTTPException(status_code=400, detail=f"Allowed setting keys: {ALLOWED_KEYS}")
|
||||
|
||||
existing = await db.get(AppSetting, body.key)
|
||||
if existing:
|
||||
existing.value = body.value
|
||||
else:
|
||||
db.add(AppSetting(key=body.key, value=body.value))
|
||||
await db.commit()
|
||||
return {"key": body.key, "value": body.value}
|
||||
|
||||
|
||||
@router.post("/test-llm")
|
||||
async def test_llm_connection():
|
||||
"""Test that the configured LLM provider responds correctly."""
|
||||
from app.services.llm_service import get_llm_provider
|
||||
try:
|
||||
provider = get_llm_provider()
|
||||
brief = provider.generate_brief(
|
||||
doc_text="This is a test bill for connection verification purposes.",
|
||||
bill_metadata={
|
||||
"title": "Test Connection Bill",
|
||||
"sponsor_name": "Test Sponsor",
|
||||
"party": "Test",
|
||||
"state": "DC",
|
||||
"chamber": "House",
|
||||
"introduced_date": "2025-01-01",
|
||||
"latest_action_text": "Test action",
|
||||
"latest_action_date": "2025-01-01",
|
||||
},
|
||||
)
|
||||
return {
|
||||
"status": "ok",
|
||||
"provider": brief.llm_provider,
|
||||
"model": brief.llm_model,
|
||||
"summary_preview": brief.summary[:100] + "..." if len(brief.summary) > 100 else brief.summary,
|
||||
}
|
||||
except Exception as e:
|
||||
return {"status": "error", "detail": str(e)}
|
||||
|
||||
|
||||
def _current_model(provider: str) -> str:
|
||||
if provider == "openai":
|
||||
return settings.OPENAI_MODEL
|
||||
elif provider == "anthropic":
|
||||
return settings.ANTHROPIC_MODEL
|
||||
elif provider == "gemini":
|
||||
return settings.GEMINI_MODEL
|
||||
elif provider == "ollama":
|
||||
return settings.OLLAMA_MODEL
|
||||
return "unknown"
|
||||
Reference in New Issue
Block a user