fix: backfill_brief_labels bulk SQL runs before ORM load to prevent session flush race
Quoteless unlabeled points (old-format briefs with no citation system) were being auto-labeled via raw SQL after db.get() had already loaded them into the session identity map. SQLAlchemy's commit-time flush could re-emit the ORM object's cached state, silently overwriting the raw UPDATE. Fix: run a single bulk SQL UPDATE for all matching rows before any ORM objects are loaded into the session. The commit is then a clean single-statement transaction with nothing else in the session to interfere with it. LLM classification of quoted points continues in a separate pass with the normal flag_modified + commit flow. Authored by: Jack Levy
This commit is contained in:
@@ -203,6 +203,37 @@ def backfill_brief_labels(self):
|
|||||||
|
|
||||||
db = get_sync_db()
|
db = get_sync_db()
|
||||||
try:
|
try:
|
||||||
|
# Step 1: Bulk auto-label quoteless unlabeled points as "inference" via raw SQL.
|
||||||
|
# This runs before any ORM objects are loaded so the session identity map cannot
|
||||||
|
# interfere with the commit (the classic "ORM flush overwrites raw UPDATE" trap).
|
||||||
|
_BULK_AUTO_LABEL = """
|
||||||
|
UPDATE bill_briefs SET {col} = (
|
||||||
|
SELECT jsonb_agg(
|
||||||
|
CASE
|
||||||
|
WHEN jsonb_typeof(p) = 'object'
|
||||||
|
AND (p->>'label') IS NULL
|
||||||
|
AND (p->>'quote') IS NULL
|
||||||
|
THEN p || '{{"label":"inference"}}'
|
||||||
|
ELSE p
|
||||||
|
END
|
||||||
|
)
|
||||||
|
FROM jsonb_array_elements({col}) AS p
|
||||||
|
)
|
||||||
|
WHERE {col} IS NOT NULL AND EXISTS (
|
||||||
|
SELECT 1 FROM jsonb_array_elements({col}) AS p
|
||||||
|
WHERE jsonb_typeof(p) = 'object'
|
||||||
|
AND (p->>'label') IS NULL
|
||||||
|
AND (p->>'quote') IS NULL
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
auto_rows = 0
|
||||||
|
for col in ("key_points", "risks"):
|
||||||
|
result = db.execute(text(_BULK_AUTO_LABEL.format(col=col)))
|
||||||
|
auto_rows += result.rowcount
|
||||||
|
db.commit()
|
||||||
|
logger.info(f"backfill_brief_labels: bulk auto-labeled {auto_rows} rows (quoteless → inference)")
|
||||||
|
|
||||||
|
# Step 2: Find briefs that still have unlabeled points (must have quotes → need LLM).
|
||||||
unlabeled_ids = db.execute(text("""
|
unlabeled_ids = db.execute(text("""
|
||||||
SELECT id FROM bill_briefs
|
SELECT id FROM bill_briefs
|
||||||
WHERE (
|
WHERE (
|
||||||
@@ -235,11 +266,11 @@ def backfill_brief_labels(self):
|
|||||||
skipped += 1
|
skipped += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Collect all unlabeled cited points across both fields
|
# Only points with a quote can be LLM-classified as cited_fact vs inference
|
||||||
to_classify: list[tuple[str, int, dict]] = []
|
to_classify: list[tuple[str, int, dict]] = []
|
||||||
for field_name in ("key_points", "risks"):
|
for field_name in ("key_points", "risks"):
|
||||||
for i, p in enumerate(getattr(brief, field_name) or []):
|
for i, p in enumerate(getattr(brief, field_name) or []):
|
||||||
if isinstance(p, dict) and p.get("label") is None:
|
if isinstance(p, dict) and p.get("label") is None and p.get("quote"):
|
||||||
to_classify.append((field_name, i, p))
|
to_classify.append((field_name, i, p))
|
||||||
|
|
||||||
if not to_classify:
|
if not to_classify:
|
||||||
@@ -289,10 +320,10 @@ def backfill_brief_labels(self):
|
|||||||
time.sleep(0.2)
|
time.sleep(0.2)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
f"backfill_brief_labels: {total} briefs found, "
|
f"backfill_brief_labels: {total} briefs needing LLM, "
|
||||||
f"{updated} updated, {skipped} skipped"
|
f"{updated} updated, {skipped} skipped"
|
||||||
)
|
)
|
||||||
return {"total": total, "updated": updated, "skipped": skipped}
|
return {"auto_labeled_rows": auto_rows, "total_llm": total, "updated": updated, "skipped": skipped}
|
||||||
finally:
|
finally:
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user