fix: backfill_brief_labels bulk SQL runs before ORM load to prevent session flush race
Quoteless unlabeled points (old-format briefs with no citation system) were being auto-labeled via raw SQL after db.get() had already loaded them into the session identity map. SQLAlchemy's commit-time flush could then re-emit the ORM object's cached state, silently overwriting the raw UPDATE.

Fix: run the bulk SQL UPDATE (one statement per column: key_points and risks) for all matching rows before any ORM objects are loaded into the session. The commit is then a clean transaction containing only those bulk statements, with nothing in the session to interfere. LLM classification of quoted points continues in a separate pass with the normal flag_modified + commit.

Authored by: Jack Levy
This commit is contained in:
@@ -203,6 +203,37 @@ def backfill_brief_labels(self):
|
||||
|
||||
db = get_sync_db()
|
||||
try:
|
||||
# Step 1: Bulk auto-label quoteless unlabeled points as "inference" via raw SQL.
|
||||
# This runs before any ORM objects are loaded so the session identity map cannot
|
||||
# interfere with the commit (the classic "ORM flush overwrites raw UPDATE" trap).
|
||||
_BULK_AUTO_LABEL = """
|
||||
UPDATE bill_briefs SET {col} = (
|
||||
SELECT jsonb_agg(
|
||||
CASE
|
||||
WHEN jsonb_typeof(p) = 'object'
|
||||
AND (p->>'label') IS NULL
|
||||
AND (p->>'quote') IS NULL
|
||||
THEN p || '{{"label":"inference"}}'
|
||||
ELSE p
|
||||
END
|
||||
)
|
||||
FROM jsonb_array_elements({col}) AS p
|
||||
)
|
||||
WHERE {col} IS NOT NULL AND EXISTS (
|
||||
SELECT 1 FROM jsonb_array_elements({col}) AS p
|
||||
WHERE jsonb_typeof(p) = 'object'
|
||||
AND (p->>'label') IS NULL
|
||||
AND (p->>'quote') IS NULL
|
||||
)
|
||||
"""
|
||||
auto_rows = 0
|
||||
for col in ("key_points", "risks"):
|
||||
result = db.execute(text(_BULK_AUTO_LABEL.format(col=col)))
|
||||
auto_rows += result.rowcount
|
||||
db.commit()
|
||||
logger.info(f"backfill_brief_labels: bulk auto-labeled {auto_rows} rows (quoteless → inference)")
|
||||
|
||||
# Step 2: Find briefs that still have unlabeled points (must have quotes → need LLM).
|
||||
unlabeled_ids = db.execute(text("""
|
||||
SELECT id FROM bill_briefs
|
||||
WHERE (
|
||||
@@ -235,11 +266,11 @@ def backfill_brief_labels(self):
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
# Collect all unlabeled cited points across both fields
|
||||
# Only points with a quote can be LLM-classified as cited_fact vs inference
|
||||
to_classify: list[tuple[str, int, dict]] = []
|
||||
for field_name in ("key_points", "risks"):
|
||||
for i, p in enumerate(getattr(brief, field_name) or []):
|
||||
if isinstance(p, dict) and p.get("label") is None:
|
||||
if isinstance(p, dict) and p.get("label") is None and p.get("quote"):
|
||||
to_classify.append((field_name, i, p))
|
||||
|
||||
if not to_classify:
|
||||
@@ -289,10 +320,10 @@ def backfill_brief_labels(self):
|
||||
time.sleep(0.2)
|
||||
|
||||
logger.info(
|
||||
f"backfill_brief_labels: {total} briefs found, "
|
||||
f"backfill_brief_labels: {total} briefs needing LLM, "
|
||||
f"{updated} updated, {skipped} skipped"
|
||||
)
|
||||
return {"total": total, "updated": updated, "skipped": skipped}
|
||||
return {"auto_labeled_rows": auto_rows, "total_llm": total, "updated": updated, "skipped": skipped}
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user