feat: LLM Batch API — OpenAI + Anthropic 50% cost reduction (v0.9.8)

Submit up to 1000 unbriefed documents to the provider Batch API in one
shot instead of individual synchronous LLM calls. Results are polled
every 30 minutes via a new Celery beat task and imported automatically.

- New worker: llm_batch_processor.py
  - submit_llm_batch: guards against duplicate batches, builds JSONL
    (OpenAI) or request list (Anthropic), stores state in AppSetting
  - poll_llm_batch_results: checks batch status, imports completed
    results with idempotency, emits notifications + triggers news fetch
- celery_app: register worker, route to llm queue, beat every 30 min
- admin API: POST /submit-llm-batch + GET /llm-batch-status endpoints
- Frontend: submitLlmBatch + getLlmBatchStatus in adminAPI; settings
  page shows batch control row (openai/anthropic only) with live
  progress line while batch is processing

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Jack Levy
2026-03-14 17:35:15 -04:00
parent 7e5c5b473e
commit cba19c7bb3
5 changed files with 467 additions and 0 deletions

View File

@@ -0,0 +1,401 @@
"""
LLM Batch processor — submits and polls OpenAI/Anthropic Batch API jobs.
50% cheaper than synchronous calls; 24-hour processing window.
New bills still use the synchronous llm_processor task.
"""
import io
import json
import logging
from datetime import datetime
from sqlalchemy import text
from app.config import settings
from app.database import get_sync_db
from app.models import Bill, BillBrief, BillDocument, Member
from app.models.setting import AppSetting
from app.services.llm_service import (
AMENDMENT_SYSTEM_PROMPT,
MAX_TOKENS_DEFAULT,
SYSTEM_PROMPT,
build_amendment_prompt,
build_prompt,
parse_brief_json,
)
from app.workers.celery_app import celery_app
logger = logging.getLogger(__name__)
_BATCH_SETTING_KEY = "llm_active_batch"
# ── State helpers ──────────────────────────────────────────────────────────────
def _save_batch_state(db, state: dict):
    """Upsert the active-batch state blob under the well-known settings key."""
    payload = json.dumps(state)
    existing = db.get(AppSetting, _BATCH_SETTING_KEY)
    if existing is None:
        db.add(AppSetting(key=_BATCH_SETTING_KEY, value=payload))
    else:
        existing.value = payload
    db.commit()
def _clear_batch_state(db):
    """Remove the persisted batch-state row, if one exists (no-op otherwise)."""
    existing = db.get(AppSetting, _BATCH_SETTING_KEY)
    if existing is None:
        return
    db.delete(existing)
    db.commit()
# ── Request builder ────────────────────────────────────────────────────────────
def _build_request_data(db, doc_id: int, bill_id: str) -> tuple[str, str, str]:
    """Returns (custom_id, system_prompt, user_prompt) for a document.

    Chooses the "amendment" prompt when the bill already has a full brief
    whose source document still has text (so the LLM can diff old vs new);
    otherwise falls back to the plain "full" brief prompt.

    Raises:
        ValueError: if the document is missing/has no text, or the bill
            does not exist.
    """
    doc = db.get(BillDocument, doc_id)
    if not doc or not doc.raw_text:
        raise ValueError(f"Document {doc_id} missing or has no text")
    bill = db.get(Bill, bill_id)
    if not bill:
        raise ValueError(f"Bill {bill_id} not found")
    sponsor = db.get(Member, bill.sponsor_id) if bill.sponsor_id else None
    bill_metadata = {
        "title": bill.title or "Unknown Title",
        "sponsor_name": sponsor.name if sponsor else "Unknown",
        "party": sponsor.party if sponsor else "Unknown",
        "state": sponsor.state if sponsor else "Unknown",
        "chamber": bill.chamber or "Unknown",
        "introduced_date": str(bill.introduced_date) if bill.introduced_date else "Unknown",
        "latest_action_text": bill.latest_action_text or "None",
        "latest_action_date": str(bill.latest_action_date) if bill.latest_action_date else "Unknown",
    }
    # Resolve the previous full brief's source document up front so there is
    # a single amendment-vs-full branch below (the original duplicated the
    # "full" branch twice).
    previous_doc = None
    previous_full_brief = (
        db.query(BillBrief)
        .filter_by(bill_id=bill_id, brief_type="full")
        .order_by(BillBrief.created_at.desc())
        .first()
    )
    if previous_full_brief and previous_full_brief.document_id:
        candidate = db.get(BillDocument, previous_full_brief.document_id)
        if candidate and candidate.raw_text:
            previous_doc = candidate
    if previous_doc:
        brief_type = "amendment"
        prompt = build_amendment_prompt(doc.raw_text, previous_doc.raw_text, bill_metadata, MAX_TOKENS_DEFAULT)
        system_prompt = AMENDMENT_SYSTEM_PROMPT + "\n\nIMPORTANT: Respond with ONLY valid JSON. No other text."
    else:
        brief_type = "full"
        prompt = build_prompt(doc.raw_text, bill_metadata, MAX_TOKENS_DEFAULT)
        system_prompt = SYSTEM_PROMPT
    custom_id = f"doc-{doc_id}-{brief_type}"
    return custom_id, system_prompt, prompt
# ── Submit task ────────────────────────────────────────────────────────────────
@celery_app.task(bind=True, name="app.workers.llm_batch_processor.submit_llm_batch")
def submit_llm_batch(self):
    """Submit all unbriefed documents to the OpenAI or Anthropic Batch API.

    Reads the active provider/model from AppSetting rows (falling back to
    static settings), refuses to run while a previous batch is still
    "processing", selects up to 1000 documents that have raw text but no
    brief, submits them via the provider-specific helper, and persists the
    batch state so poll_llm_batch_results can import the output later.

    Returns:
        dict: a small status payload — one of "unsupported",
        "already_active", "nothing_to_process", or "submitted" (the latter
        includes batch_id and doc_count).
    """
    db = get_sync_db()
    try:
        # Provider/model can be overridden at runtime via AppSetting rows;
        # fall back to the static config values otherwise.
        prov_row = db.get(AppSetting, "llm_provider")
        model_row = db.get(AppSetting, "llm_model")
        provider_name = ((prov_row.value if prov_row else None) or settings.LLM_PROVIDER).lower()
        if provider_name not in ("openai", "anthropic"):
            return {"status": "unsupported", "provider": provider_name}
        # Check for already-active batch: only one batch may be in flight at
        # a time, because the single state row is what the poll task reads.
        active_row = db.get(AppSetting, _BATCH_SETTING_KEY)
        if active_row:
            try:
                active = json.loads(active_row.value)
                if active.get("status") == "processing":
                    return {"status": "already_active", "batch_id": active.get("batch_id")}
            except Exception:
                # Corrupt state JSON — treat as no active batch; it will be
                # overwritten by _save_batch_state below.
                pass
        # Find docs with text but no brief (LEFT JOIN anti-join), capped at
        # 1000 per submission.
        rows = db.execute(text("""
            SELECT bd.id AS doc_id, bd.bill_id, bd.govinfo_url
            FROM bill_documents bd
            LEFT JOIN bill_briefs bb ON bb.document_id = bd.id
            WHERE bd.raw_text IS NOT NULL AND bb.id IS NULL
            LIMIT 1000
        """)).fetchall()
        if not rows:
            return {"status": "nothing_to_process"}
        doc_ids = [r.doc_id for r in rows]
        if provider_name == "openai":
            model = (model_row.value if model_row else None) or settings.OPENAI_MODEL
            batch_id = _submit_openai_batch(db, rows, model)
        else:
            model = (model_row.value if model_row else None) or settings.ANTHROPIC_MODEL
            batch_id = _submit_anthropic_batch(db, rows, model)
        # Persist everything the poll task needs to import results later.
        state = {
            "batch_id": batch_id,
            "provider": provider_name,
            "model": model,
            "doc_ids": doc_ids,
            "doc_count": len(doc_ids),
            "submitted_at": datetime.utcnow().isoformat(),
            "status": "processing",
        }
        _save_batch_state(db, state)
        logger.info(f"Submitted {len(doc_ids)}-doc batch to {provider_name}: {batch_id}")
        return {"status": "submitted", "batch_id": batch_id, "doc_count": len(doc_ids)}
    finally:
        db.close()
def _submit_openai_batch(db, rows, model: str) -> str:
    """Build a JSONL request file, upload it, and create an OpenAI batch.

    Rows whose request data cannot be built are logged and skipped.
    Returns the provider-assigned batch id.
    """
    from openai import OpenAI
    client = OpenAI(api_key=settings.OPENAI_API_KEY)
    request_lines = []
    for row in rows:
        try:
            custom_id, system_prompt, prompt = _build_request_data(db, row.doc_id, row.bill_id)
        except Exception as exc:
            logger.warning(f"Skipping doc {row.doc_id}: {exc}")
            continue
        body = {
            "model": model,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": prompt},
            ],
            "response_format": {"type": "json_object"},
            "temperature": 0.1,
            "max_tokens": MAX_TOKENS_DEFAULT,
        }
        request_lines.append(json.dumps({
            "custom_id": custom_id,
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": body,
        }))
    payload = "\n".join(request_lines).encode()
    uploaded = client.files.create(
        file=("batch.jsonl", io.BytesIO(payload), "application/jsonl"),
        purpose="batch",
    )
    batch = client.batches.create(
        input_file_id=uploaded.id,
        endpoint="/v1/chat/completions",
        completion_window="24h",
    )
    return batch.id
def _submit_anthropic_batch(db, rows, model: str) -> str:
    """Create an Anthropic Message Batch covering every buildable row.

    Rows whose request data cannot be built are logged and skipped.
    Returns the provider-assigned batch id.
    """
    import anthropic
    client = anthropic.Anthropic(api_key=settings.ANTHROPIC_API_KEY)
    batch_requests = []
    for row in rows:
        try:
            custom_id, system_prompt, prompt = _build_request_data(db, row.doc_id, row.bill_id)
        except Exception as exc:
            logger.warning(f"Skipping doc {row.doc_id}: {exc}")
            continue
        params = {
            "model": model,
            # NOTE(review): hard-coded 4096 while the OpenAI path uses
            # MAX_TOKENS_DEFAULT — confirm this divergence is intentional.
            "max_tokens": 4096,
            # cache_control marks the system prompt for prompt caching.
            "system": [{"type": "text", "text": system_prompt, "cache_control": {"type": "ephemeral"}}],
            "messages": [{"role": "user", "content": prompt}],
        }
        batch_requests.append({"custom_id": custom_id, "params": params})
    batch = client.messages.batches.create(requests=batch_requests)
    return batch.id
# ── Poll task ──────────────────────────────────────────────────────────────────
@celery_app.task(bind=True, name="app.workers.llm_batch_processor.poll_llm_batch_results")
def poll_llm_batch_results(self):
    """Check active batch status and import completed results (runs every 30 min via beat).

    Loads the persisted batch state and dispatches to the provider-specific
    poll helper. State that is unreadable OR missing required keys is
    cleared — previously a state blob that parsed as JSON but lacked
    "batch_id"/"provider"/"model" raised KeyError outside the guard and
    crashed every poll run until manual intervention.

    Returns:
        dict: a status payload from the provider helper, or one of
        "no_active_batch" / "invalid_state" / "unknown_provider".
    """
    db = get_sync_db()
    try:
        active_row = db.get(AppSetting, _BATCH_SETTING_KEY)
        if not active_row:
            return {"status": "no_active_batch"}
        try:
            state = json.loads(active_row.value)
            batch_id = state["batch_id"]
            provider_name = state["provider"]
            model = state["model"]
        except (ValueError, KeyError, TypeError):
            # Unparseable or incomplete state: drop it so a fresh batch can
            # be submitted instead of failing forever.
            _clear_batch_state(db)
            return {"status": "invalid_state"}
        if provider_name == "openai":
            return _poll_openai(db, state, batch_id, model)
        elif provider_name == "anthropic":
            return _poll_anthropic(db, state, batch_id, model)
        else:
            _clear_batch_state(db)
            return {"status": "unknown_provider"}
    finally:
        db.close()
# ── Result processing helpers ──────────────────────────────────────────────────
def _save_brief(db, doc_id: int, bill_id: str, brief, brief_type: str, govinfo_url) -> bool:
    """Idempotency check + save. Returns True if saved, False if already exists.

    The per-document existence check makes result import safe to re-run
    (e.g. if a poll crashes partway through a batch).
    """
    already_briefed = db.query(BillBrief).filter_by(document_id=doc_id).first() is not None
    if already_briefed:
        return False
    record = BillBrief(
        bill_id=bill_id,
        document_id=doc_id,
        brief_type=brief_type,
        summary=brief.summary,
        key_points=brief.key_points,
        risks=brief.risks,
        deadlines=brief.deadlines,
        topic_tags=brief.topic_tags,
        llm_provider=brief.llm_provider,
        llm_model=brief.llm_model,
        govinfo_url=govinfo_url,
    )
    db.add(record)
    db.commit()
    db.refresh(record)
    return True
def _emit_notifications_and_news(db, bill_id: str, brief, brief_type: str):
    """Fan out bill/member/topic notifications for a new brief and queue a news fetch."""
    bill = db.get(Bill, bill_id)
    if bill is None:
        return
    # Imported lazily, mirroring the module's pattern for worker-to-worker deps.
    from app.workers.notification_utils import (
        emit_bill_notification,
        emit_member_follow_notifications,
        emit_topic_follow_notifications,
    )
    event_type = "new_amendment" if brief_type == "amendment" else "new_document"
    tags = brief.topic_tags or []
    emit_bill_notification(db, bill, event_type, brief.summary)
    emit_member_follow_notifications(db, bill, event_type, brief.summary)
    emit_topic_follow_notifications(db, bill, event_type, brief.summary, tags)
    from app.workers.news_fetcher import fetch_news_for_bill
    fetch_news_for_bill.delay(bill_id)
def _parse_custom_id(custom_id: str) -> tuple[int, str]:
"""Parse 'doc-{doc_id}-{brief_type}' → (doc_id, brief_type)."""
parts = custom_id.split("-")
return int(parts[1]), parts[2]
def _poll_openai(db, state: dict, batch_id: str, model: str) -> dict:
    """Retrieve an OpenAI batch and, once completed, import its JSONL results.

    Terminal failure states clear the batch state; an in-flight batch
    returns a "processing" payload. Individual result lines that error out
    are counted as failures without aborting the import.
    """
    from openai import OpenAI
    client = OpenAI(api_key=settings.OPENAI_API_KEY)
    batch = client.batches.retrieve(batch_id)
    logger.info(f"OpenAI batch {batch_id} status: {batch.status}")
    terminal_failures = ("failed", "cancelled", "expired")
    if batch.status in terminal_failures:
        _clear_batch_state(db)
        return {"status": batch.status}
    if batch.status != "completed":
        return {"status": "processing", "batch_status": batch.status}
    content = client.files.content(batch.output_file_id).read().decode()
    saved = failed = 0
    for line in content.splitlines():
        if not line.strip():
            continue
        try:
            item = json.loads(line)
            custom_id = item["custom_id"]
            doc_id, brief_type = _parse_custom_id(custom_id)
            if item.get("error"):
                logger.warning(f"Batch result error for {custom_id}: {item['error']}")
                failed += 1
                continue
            raw = item["response"]["body"]["choices"][0]["message"]["content"]
            brief = parse_brief_json(raw, "openai", model)
            doc = db.get(BillDocument, doc_id)
            if doc is None:
                failed += 1
                continue
            was_saved = _save_brief(db, doc_id, doc.bill_id, brief, brief_type, doc.govinfo_url)
            if was_saved:
                _emit_notifications_and_news(db, doc.bill_id, brief, brief_type)
                saved += 1
        except Exception as exc:
            logger.warning(f"Failed to process OpenAI batch result line: {exc}")
            failed += 1
    _clear_batch_state(db)
    logger.info(f"OpenAI batch {batch_id} complete: {saved} saved, {failed} failed")
    return {"status": "completed", "saved": saved, "failed": failed}
def _poll_anthropic(db, state: dict, batch_id: str, model: str) -> dict:
    """Check an Anthropic message batch and import its results once ended.

    An in-flight batch returns a "processing" payload. Once the batch has
    ended, each result is parsed and saved; individual failures are logged
    and counted without aborting the rest of the import.

    Returns:
        dict: {"status": "processing", ...} while pending, or
        {"status": "completed", "saved": n, "failed": m} after import.
    """
    import anthropic
    client = anthropic.Anthropic(api_key=settings.ANTHROPIC_API_KEY)
    batch = client.messages.batches.retrieve(batch_id)
    logger.info(f"Anthropic batch {batch_id} processing_status: {batch.processing_status}")
    # Anything other than "ended" means results are not ready yet.
    if batch.processing_status != "ended":
        return {"status": "processing", "batch_status": batch.processing_status}
    saved = failed = 0
    for result in client.messages.batches.results(batch_id):
        try:
            custom_id = result.custom_id
            doc_id, brief_type = _parse_custom_id(custom_id)
            # Per-request outcomes: anything but "succeeded" (errored,
            # canceled, expired) is counted as a failure.
            if result.result.type != "succeeded":
                logger.warning(f"Batch result {custom_id} type: {result.result.type}")
                failed += 1
                continue
            raw = result.result.message.content[0].text
            brief = parse_brief_json(raw, "anthropic", model)
            doc = db.get(BillDocument, doc_id)
            if not doc:
                failed += 1
                continue
            # _save_brief is idempotent per document: notify + count only on
            # the first save, so re-polling never double-notifies.
            if _save_brief(db, doc_id, doc.bill_id, brief, brief_type, doc.govinfo_url):
                _emit_notifications_and_news(db, doc.bill_id, brief, brief_type)
                saved += 1
        except Exception as exc:
            logger.warning(f"Failed to process Anthropic batch result: {exc}")
            failed += 1
    _clear_batch_state(db)
    logger.info(f"Anthropic batch {batch_id} complete: {saved} saved, {failed} failed")
    return {"status": "completed", "saved": saved, "failed": failed}