"""
LLM batch processor — submits and polls OpenAI/Anthropic Batch API jobs.

Batch jobs cost 50% less than synchronous calls in exchange for a 24-hour
processing window. New bills still go through the synchronous llm_processor
task; this module works through documents that have text but no brief yet.

Only one batch is tracked at a time: its state is stored as JSON in the
AppSetting row keyed by ``_BATCH_SETTING_KEY``.
"""
import io
import json
import logging
from datetime import datetime

from sqlalchemy import text

from app.config import settings
from app.database import get_sync_db
from app.models import Bill, BillBrief, BillDocument, Member
from app.models.setting import AppSetting
from app.services.llm_service import (
    AMENDMENT_SYSTEM_PROMPT,
    MAX_TOKENS_DEFAULT,
    SYSTEM_PROMPT,
    build_amendment_prompt,
    build_prompt,
    parse_brief_json,
)
from app.workers.celery_app import celery_app

logger = logging.getLogger(__name__)

# AppSetting key holding the JSON state of the currently tracked batch.
_BATCH_SETTING_KEY = "llm_active_batch"

# Maximum number of documents selected for a single batch submission.
_MAX_DOCS_PER_BATCH = 1000


# ── State helpers ──────────────────────────────────────────────────────────────

def _save_batch_state(db, state: dict):
    """Store ``state`` as JSON under ``_BATCH_SETTING_KEY`` (insert or update) and commit."""
    row = db.get(AppSetting, _BATCH_SETTING_KEY)
    if row:
        row.value = json.dumps(state)
    else:
        db.add(AppSetting(key=_BATCH_SETTING_KEY, value=json.dumps(state)))
    db.commit()


def _clear_batch_state(db):
    """Delete the stored batch state, if present, and commit."""
    row = db.get(AppSetting, _BATCH_SETTING_KEY)
    if row:
        db.delete(row)
        db.commit()


# ── Request builder ────────────────────────────────────────────────────────────

def _build_request_data(db, doc_id: int, bill_id: str) -> tuple[str, str, str]:
    """Return ``(custom_id, system_prompt, user_prompt)`` for one document.

    If the bill already has a "full" brief whose source document still has
    text, an amendment prompt is built from both texts; otherwise a full
    prompt is built. ``custom_id`` encodes the document id and brief type as
    ``doc-{doc_id}-{brief_type}`` so results can be matched back later.

    Raises:
        ValueError: if the document is missing or has no text, or the bill
            does not exist.
    """
    doc = db.get(BillDocument, doc_id)
    if not doc or not doc.raw_text:
        raise ValueError(f"Document {doc_id} missing or has no text")

    bill = db.get(Bill, bill_id)
    if not bill:
        raise ValueError(f"Bill {bill_id} not found")

    sponsor = db.get(Member, bill.sponsor_id) if bill.sponsor_id else None
    bill_metadata = {
        "title": bill.title or "Unknown Title",
        "sponsor_name": sponsor.name if sponsor else "Unknown",
        "party": sponsor.party if sponsor else "Unknown",
        "state": sponsor.state if sponsor else "Unknown",
        "chamber": bill.chamber or "Unknown",
        "introduced_date": str(bill.introduced_date) if bill.introduced_date else "Unknown",
        "latest_action_text": bill.latest_action_text or "None",
        "latest_action_date": str(bill.latest_action_date) if bill.latest_action_date else "Unknown",
    }

    previous_full_brief = (
        db.query(BillBrief)
        .filter_by(bill_id=bill_id, brief_type="full")
        .order_by(BillBrief.created_at.desc())
        .first()
    )
    previous_doc = None
    if previous_full_brief and previous_full_brief.document_id:
        previous_doc = db.get(BillDocument, previous_full_brief.document_id)

    if previous_doc and previous_doc.raw_text:
        brief_type = "amendment"
        prompt = build_amendment_prompt(doc.raw_text, previous_doc.raw_text, bill_metadata, MAX_TOKENS_DEFAULT)
        system_prompt = AMENDMENT_SYSTEM_PROMPT + "\n\nIMPORTANT: Respond with ONLY valid JSON. No other text."
    else:
        brief_type = "full"
        prompt = build_prompt(doc.raw_text, bill_metadata, MAX_TOKENS_DEFAULT)
        system_prompt = SYSTEM_PROMPT

    custom_id = f"doc-{doc_id}-{brief_type}"
    return custom_id, system_prompt, prompt


def _collect_batch_requests(db, rows) -> list[tuple[int, str, str, str]]:
    """Build ``(doc_id, custom_id, system_prompt, user_prompt)`` for every buildable row.

    Rows whose request cannot be built are logged and skipped (best effort),
    so one bad document does not abort the whole submission.
    """
    batch_requests = []
    for row in rows:
        try:
            custom_id, system_prompt, prompt = _build_request_data(db, row.doc_id, row.bill_id)
        except Exception as exc:
            logger.warning(f"Skipping doc {row.doc_id}: {exc}")
            continue
        batch_requests.append((row.doc_id, custom_id, system_prompt, prompt))
    return batch_requests


# ── Submit task ────────────────────────────────────────────────────────────────

@celery_app.task(bind=True, name="app.workers.llm_batch_processor.submit_llm_batch")
def submit_llm_batch(self):
    """Submit all unbriefed documents to the OpenAI or Anthropic Batch API.

    The provider and model come from the ``llm_provider`` / ``llm_model``
    AppSetting rows, falling back to ``settings``. Nothing is submitted while
    a previously submitted batch is still recorded as processing.

    Returns a status dict whose ``status`` is ``unsupported``,
    ``already_active``, ``nothing_to_process`` or ``submitted`` (the latter
    with ``batch_id`` and ``doc_count``).
    """
    db = get_sync_db()
    try:
        prov_row = db.get(AppSetting, "llm_provider")
        model_row = db.get(AppSetting, "llm_model")
        provider_name = ((prov_row.value if prov_row else None) or settings.LLM_PROVIDER).lower()
        if provider_name not in ("openai", "anthropic"):
            return {"status": "unsupported", "provider": provider_name}

        # Refuse to start a second batch while one is still being processed.
        active_row = db.get(AppSetting, _BATCH_SETTING_KEY)
        if active_row:
            try:
                active = json.loads(active_row.value)
            except (TypeError, ValueError):
                # Unreadable state must not block submissions; it gets
                # overwritten below (and the poll task clears it as invalid).
                logger.warning(f"Ignoring unreadable {_BATCH_SETTING_KEY} state")
            else:
                if isinstance(active, dict) and active.get("status") == "processing":
                    return {"status": "already_active", "batch_id": active.get("batch_id")}

        # Documents with text but no brief. Empty strings are excluded because
        # _build_request_data rejects them; left in, they would be re-selected on
        # every run and could crowd real documents out of the LIMIT.
        rows = db.execute(
            text("""
                SELECT bd.id AS doc_id, bd.bill_id, bd.govinfo_url
                FROM bill_documents bd
                LEFT JOIN bill_briefs bb ON bb.document_id = bd.id
                WHERE bd.raw_text IS NOT NULL
                  AND bd.raw_text <> ''
                  AND bb.id IS NULL
                LIMIT :limit
            """),
            {"limit": _MAX_DOCS_PER_BATCH},
        ).fetchall()
        if not rows:
            return {"status": "nothing_to_process"}

        batch_requests = _collect_batch_requests(db, rows)
        if not batch_requests:
            # Every candidate was skipped, so there is nothing to send.
            return {"status": "nothing_to_process"}
        doc_ids = [doc_id for doc_id, _, _, _ in batch_requests]

        if provider_name == "openai":
            model = (model_row.value if model_row else None) or settings.OPENAI_MODEL
            batch_id = _submit_openai_batch(batch_requests, model)
        else:
            model = (model_row.value if model_row else None) or settings.ANTHROPIC_MODEL
            batch_id = _submit_anthropic_batch(batch_requests, model)

        state = {
            "batch_id": batch_id,
            "provider": provider_name,
            "model": model,
            "doc_ids": doc_ids,
            "doc_count": len(doc_ids),
            "submitted_at": datetime.utcnow().isoformat(),
            "status": "processing",
        }
        _save_batch_state(db, state)
        logger.info(f"Submitted {len(doc_ids)}-doc batch to {provider_name}: {batch_id}")
        return {"status": "submitted", "batch_id": batch_id, "doc_count": len(doc_ids)}
    finally:
        db.close()


def _submit_openai_batch(batch_requests, model: str) -> str:
    """Upload ``batch_requests`` as JSONL and start an OpenAI chat-completions batch.

    ``batch_requests`` holds ``(doc_id, custom_id, system_prompt, user_prompt)``
    tuples from ``_collect_batch_requests``. Returns the provider batch id.
    """
    from openai import OpenAI

    client = OpenAI(api_key=settings.OPENAI_API_KEY)
    lines = []
    for _, custom_id, system_prompt, prompt in batch_requests:
        lines.append(json.dumps({
            "custom_id": custom_id,
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": model,
                "messages": [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": prompt},
                ],
                "response_format": {"type": "json_object"},
                "temperature": 0.1,
                "max_tokens": MAX_TOKENS_DEFAULT,
            },
        }))

    jsonl_bytes = "\n".join(lines).encode()
    file_obj = client.files.create(
        file=("batch.jsonl", io.BytesIO(jsonl_bytes), "application/jsonl"),
        purpose="batch",
    )
    batch = client.batches.create(
        input_file_id=file_obj.id,
        endpoint="/v1/chat/completions",
        completion_window="24h",
    )
    return batch.id


def _submit_anthropic_batch(batch_requests, model: str) -> str:
    """Start an Anthropic Message Batch for ``batch_requests`` and return its id.

    ``batch_requests`` holds ``(doc_id, custom_id, system_prompt, user_prompt)``
    tuples from ``_collect_batch_requests``. The system prompt is marked with
    ``cache_control`` so it can be prompt-cached across requests.
    """
    import anthropic

    client = anthropic.Anthropic(api_key=settings.ANTHROPIC_API_KEY)
    requests = [
        {
            "custom_id": custom_id,
            "params": {
                "model": model,
                # NOTE(review): the OpenAI path uses MAX_TOKENS_DEFAULT as its output
                # cap; confirm the fixed 4096 here is intentional.
                "max_tokens": 4096,
                "system": [{"type": "text", "text": system_prompt, "cache_control": {"type": "ephemeral"}}],
                "messages": [{"role": "user", "content": prompt}],
            },
        }
        for _, custom_id, system_prompt, prompt in batch_requests
    ]
    batch = client.messages.batches.create(requests=requests)
    return batch.id


# ── Poll task ──────────────────────────────────────────────────────────────────

@celery_app.task(bind=True, name="app.workers.llm_batch_processor.poll_llm_batch_results")
def poll_llm_batch_results(self):
    """Check active batch status and import completed results (runs every 30 min via beat).

    Returns ``no_active_batch``, ``invalid_state`` or ``unknown_provider``
    status dicts, or whatever the provider-specific poller returns.
    """
    db = get_sync_db()
    try:
        active_row = db.get(AppSetting, _BATCH_SETTING_KEY)
        if not active_row:
            return {"status": "no_active_batch"}

        try:
            state = json.loads(active_row.value)
            batch_id = state["batch_id"]
            provider_name = state["provider"]
            model = state["model"]
        except (TypeError, ValueError, KeyError):
            # Corrupt or incomplete state can never be polled; drop it so that
            # it stops blocking new submissions.
            _clear_batch_state(db)
            return {"status": "invalid_state"}

        if provider_name == "openai":
            return _poll_openai(db, state, batch_id, model)
        if provider_name == "anthropic":
            return _poll_anthropic(db, state, batch_id, model)
        _clear_batch_state(db)
        return {"status": "unknown_provider"}
    finally:
        db.close()


# ── Result processing helpers ──────────────────────────────────────────────────

def _save_brief(db, doc_id: int, bill_id: str, brief, brief_type: str, govinfo_url) -> bool:
    """Save ``brief`` for ``doc_id`` unless one already exists, committing on success.

    Returns True if a new row was saved, False if the document already had a
    brief (e.g. produced meanwhile by another path), which keeps re-imports
    idempotent.
    """
    if db.query(BillBrief).filter_by(document_id=doc_id).first():
        return False
    db_brief = BillBrief(
        bill_id=bill_id,
        document_id=doc_id,
        brief_type=brief_type,
        summary=brief.summary,
        key_points=brief.key_points,
        risks=brief.risks,
        deadlines=brief.deadlines,
        topic_tags=brief.topic_tags,
        llm_provider=brief.llm_provider,
        llm_model=brief.llm_model,
        govinfo_url=govinfo_url,
    )
    db.add(db_brief)
    db.commit()
    db.refresh(db_brief)
    return True


def _emit_notifications_and_news(db, bill_id: str, brief, brief_type: str):
    """Emit bill/member/topic follow notifications for a new brief and queue a news fetch.

    Does nothing if the bill row no longer exists.
    """
    bill = db.get(Bill, bill_id)
    if not bill:
        return

    from app.workers.notification_utils import (
        emit_bill_notification,
        emit_member_follow_notifications,
        emit_topic_follow_notifications,
    )

    event_type = "new_amendment" if brief_type == "amendment" else "new_document"
    emit_bill_notification(db, bill, event_type, brief.summary)
    emit_member_follow_notifications(db, bill, event_type, brief.summary)
    emit_topic_follow_notifications(db, bill, event_type, brief.summary, brief.topic_tags or [])

    from app.workers.news_fetcher import fetch_news_for_bill
    fetch_news_for_bill.delay(bill_id)


def _parse_custom_id(custom_id: str) -> tuple[int, str]:
    """Parse ``'doc-{doc_id}-{brief_type}'`` into ``(doc_id, brief_type)``.

    Raises:
        ValueError: if ``custom_id`` does not have that shape.
    """
    prefix, doc_id, brief_type = custom_id.split("-", 2)
    if prefix != "doc":
        raise ValueError(f"Unexpected custom_id: {custom_id!r}")
    return int(doc_id), brief_type


def _import_result(db, doc_id: int, brief_type: str, raw: str, provider_name: str, model: str) -> str:
    """Parse one raw model response and store it as the brief for ``doc_id``.

    Returns ``"saved"`` when a new brief was stored (notifications and a news
    fetch are then triggered), ``"duplicate"`` when the document already had
    a brief, or ``"missing"`` when the document row no longer exists.
    Exceptions from parsing, saving or notifying propagate to the caller.
    """
    brief = parse_brief_json(raw, provider_name, model)
    doc = db.get(BillDocument, doc_id)
    if not doc:
        return "missing"
    if not _save_brief(db, doc_id, doc.bill_id, brief, brief_type, doc.govinfo_url):
        return "duplicate"
    _emit_notifications_and_news(db, doc.bill_id, brief, brief_type)
    return "saved"


def _poll_openai(db, state: dict, batch_id: str, model: str) -> dict:
    """Check an OpenAI batch and import its output once it reaches a terminal state.

    ``completed``, ``expired`` and ``cancelled`` batches can all carry an
    output file with the responses that did finish (and were billed), so each
    of them is imported. ``failed`` batches are dropped. Documents still left
    without a brief are picked up again by the next submit run.
    """
    from openai import OpenAI

    client = OpenAI(api_key=settings.OPENAI_API_KEY)
    batch = client.batches.retrieve(batch_id)
    logger.info(f"OpenAI batch {batch_id} status: {batch.status}")

    if batch.status == "failed":
        _clear_batch_state(db)
        return {"status": batch.status}
    if batch.status not in ("completed", "expired", "cancelled"):
        return {"status": "processing", "batch_status": batch.status}

    if batch.output_file_id:
        content = client.files.content(batch.output_file_id).read().decode()
    else:
        # No output file means no request produced a successful response;
        # fetching content for a None id would fail on every poll.
        logger.warning(
            f"OpenAI batch {batch_id} ({batch.status}) has no output file; "
            f"error file: {batch.error_file_id}"
        )
        content = ""

    saved = failed = 0
    # Split on "\n" only: str.splitlines() would also split on U+2028/U+2029,
    # which may appear unescaped inside JSON string values.
    for line in content.split("\n"):
        if not line.strip():
            continue
        try:
            item = json.loads(line)
            custom_id = item["custom_id"]
            doc_id, brief_type = _parse_custom_id(custom_id)
            if item.get("error"):
                logger.warning(f"Batch result error for {custom_id}: {item['error']}")
                failed += 1
                continue
            raw = item["response"]["body"]["choices"][0]["message"]["content"]
            outcome = _import_result(db, doc_id, brief_type, raw, "openai", model)
        except Exception as exc:
            # A failed commit leaves the session unusable until rolled back; without
            # this, every later item and the final state cleanup would fail too.
            db.rollback()
            logger.warning(f"Failed to process OpenAI batch result line: {exc}")
            failed += 1
            continue
        if outcome == "saved":
            saved += 1
        elif outcome == "missing":
            failed += 1

    _clear_batch_state(db)
    logger.info(f"OpenAI batch {batch_id} complete ({batch.status}): {saved} saved, {failed} failed")
    return {"status": batch.status, "saved": saved, "failed": failed}


def _poll_anthropic(db, state: dict, batch_id: str, model: str) -> dict:
    """Check an Anthropic Message Batch and import its results once it has ended.

    Results whose type is not ``succeeded`` are counted as failures; those
    documents are picked up again by the next submit run.
    """
    import anthropic

    client = anthropic.Anthropic(api_key=settings.ANTHROPIC_API_KEY)
    batch = client.messages.batches.retrieve(batch_id)
    logger.info(f"Anthropic batch {batch_id} processing_status: {batch.processing_status}")

    if batch.processing_status != "ended":
        return {"status": "processing", "batch_status": batch.processing_status}

    saved = failed = 0
    for result in client.messages.batches.results(batch_id):
        try:
            custom_id = result.custom_id
            doc_id, brief_type = _parse_custom_id(custom_id)
            if result.result.type != "succeeded":
                logger.warning(f"Batch result {custom_id} type: {result.result.type}")
                failed += 1
                continue
            raw = result.result.message.content[0].text
            outcome = _import_result(db, doc_id, brief_type, raw, "anthropic", model)
        except Exception as exc:
            # Restore a usable session after a failed commit (see _poll_openai).
            db.rollback()
            logger.warning(f"Failed to process Anthropic batch result: {exc}")
            failed += 1
            continue
        if outcome == "saved":
            saved += 1
        elif outcome == "missing":
            failed += 1

    _clear_batch_state(db)
    logger.info(f"Anthropic batch {batch_id} complete: {saved} saved, {failed} failed")
    return {"status": "completed", "saved": saved, "failed": failed}