fix: backfill_brief_labels bulk SQL runs before ORM load to prevent session flush race

Quoteless unlabeled points (old-format briefs with no citation system) were
being auto-labeled via raw SQL after db.get() loaded them into the session
identity map. SQLAlchemy's commit-time flush could re-emit the ORM object's
cached state, silently overwriting the raw UPDATE.

Fix: run bulk SQL UPDATEs (one per JSON column, key_points and risks) for all
matching rows before any ORM objects are loaded into the session. The commit then
covers only those bulk statements, with nothing in the identity map to interfere.
LLM classification of quoted points continues in a separate pass with normal
flag_modified + commit.

Authored by: Jack Levy
This commit is contained in:
Jack Levy
2026-03-14 19:28:33 -04:00
parent 41f6f96077
commit 5e52cf5903

View File

@@ -203,6 +203,37 @@ def backfill_brief_labels(self):
db = get_sync_db() db = get_sync_db()
try: try:
# Step 1: Bulk auto-label quoteless unlabeled points as "inference" via raw SQL.
# This runs before any ORM objects are loaded so the session identity map cannot
# interfere with the commit (the classic "ORM flush overwrites raw UPDATE" trap).
_BULK_AUTO_LABEL = """
UPDATE bill_briefs SET {col} = (
SELECT jsonb_agg(
CASE
WHEN jsonb_typeof(p) = 'object'
AND (p->>'label') IS NULL
AND (p->>'quote') IS NULL
THEN p || '{{"label":"inference"}}'
ELSE p
END
)
FROM jsonb_array_elements({col}) AS p
)
WHERE {col} IS NOT NULL AND EXISTS (
SELECT 1 FROM jsonb_array_elements({col}) AS p
WHERE jsonb_typeof(p) = 'object'
AND (p->>'label') IS NULL
AND (p->>'quote') IS NULL
)
"""
auto_rows = 0
for col in ("key_points", "risks"):
result = db.execute(text(_BULK_AUTO_LABEL.format(col=col)))
auto_rows += result.rowcount
db.commit()
logger.info(f"backfill_brief_labels: bulk auto-labeled {auto_rows} rows (quoteless → inference)")
# Step 2: Find briefs that still have unlabeled points (must have quotes → need LLM).
unlabeled_ids = db.execute(text(""" unlabeled_ids = db.execute(text("""
SELECT id FROM bill_briefs SELECT id FROM bill_briefs
WHERE ( WHERE (
@@ -235,11 +266,11 @@ def backfill_brief_labels(self):
skipped += 1 skipped += 1
continue continue
# Collect all unlabeled cited points across both fields # Only points with a quote can be LLM-classified as cited_fact vs inference
to_classify: list[tuple[str, int, dict]] = [] to_classify: list[tuple[str, int, dict]] = []
for field_name in ("key_points", "risks"): for field_name in ("key_points", "risks"):
for i, p in enumerate(getattr(brief, field_name) or []): for i, p in enumerate(getattr(brief, field_name) or []):
if isinstance(p, dict) and p.get("label") is None: if isinstance(p, dict) and p.get("label") is None and p.get("quote"):
to_classify.append((field_name, i, p)) to_classify.append((field_name, i, p))
if not to_classify: if not to_classify:
@@ -289,10 +320,10 @@ def backfill_brief_labels(self):
time.sleep(0.2) time.sleep(0.2)
logger.info( logger.info(
f"backfill_brief_labels: {total} briefs found, " f"backfill_brief_labels: {total} briefs needing LLM, "
f"{updated} updated, {skipped} skipped" f"{updated} updated, {skipped} skipped"
) )
return {"total": total, "updated": updated, "skipped": skipped} return {"auto_labeled_rows": auto_rows, "total_llm": total, "updated": updated, "skipped": skipped}
finally: finally:
db.close() db.close()