328 lines
13 KiB
Python
328 lines
13 KiB
Python
"""
|
|
LLM provider abstraction.
|
|
|
|
All providers implement generate_brief(doc_text, bill_metadata) -> ReverseBrief.
|
|
Select provider via LLM_PROVIDER env var.
|
|
"""
|
|
import json
|
|
import logging
|
|
import re
|
|
from abc import ABC, abstractmethod
|
|
from dataclasses import dataclass, field
|
|
|
|
from app.config import settings
|
|
|
|
# Module logger; used to report recoverable problems such as Ollama parse retries.
logger = logging.getLogger(__name__)


# System prompt for full-bill briefs; every provider's generate_brief sends it.
# The JSON keys requested here (summary, key_points, risks, deadlines, topic_tags)
# are the ones parse_brief_json reads into a ReverseBrief, so keep them in sync.
SYSTEM_PROMPT = """You are a nonpartisan legislative analyst specializing in translating complex \
legislation into clear, accurate summaries for informed citizens. You analyze bills objectively \
without political bias.

Always respond with valid JSON matching exactly this schema:
{
"summary": "2-4 paragraph plain-language summary of what this bill does",
"key_points": ["specific concrete fact 1", "specific concrete fact 2"],
"risks": ["legitimate concern or challenge 1", "legitimate concern 2"],
"deadlines": [{"date": "YYYY-MM-DD or null", "description": "what happens on this date"}],
"topic_tags": ["healthcare", "taxation"]
}

Rules:
- summary: Explain WHAT the bill does, not whether it is good or bad. Be factual and complete.
- key_points: 5-10 specific, concrete things the bill changes, authorizes, or appropriates.
- risks: Legitimate concerns from any perspective — costs, implementation challenges, \
constitutional questions, unintended consequences. Include at least 2 even for benign bills.
- deadlines: Only include if explicitly stated in the text. Use null for date if a deadline \
is mentioned without a specific date. Empty list if none.
- topic_tags: 3-8 lowercase tags. Prefer these standard tags: healthcare, taxation, defense, \
education, immigration, environment, housing, infrastructure, technology, agriculture, judiciary, \
foreign-policy, veterans, social-security, trade, budget, energy, banking, transportation, \
public-lands, labor, civil-rights, science.

Respond with ONLY valid JSON. No preamble, no explanation, no markdown code blocks."""

# Approximate *input* token budgets for the bill text passed to build_prompt /
# build_amendment_prompt, which truncate it via smart_truncate. They do not cap the
# length of the model's reply.
# OpenAI, Anthropic and Gemini providers use the default budget.
MAX_TOKENS_DEFAULT = 6000
# Smaller budget used by OllamaProvider (presumably for smaller local context windows — assumption).
MAX_TOKENS_OLLAMA = 3000
TOKENS_PER_CHAR = 0.25 # rough approximation: 4 chars ≈ 1 token
|
|
|
|
|
|
@dataclass
class ReverseBrief:
    """Structured, plain-language brief of a bill (or of a bill amendment).

    Built by ``parse_brief_json`` from an LLM's JSON reply, tagged with the
    provider and model that produced it. Field order is the positional
    constructor order.
    """

    # Multi-paragraph plain-language summary text.
    summary: str
    # Concrete changes / authorizations / appropriations called out by the model.
    key_points: list[str]
    # Concerns or challenges the model identified.
    risks: list[str]
    # Deadline entries; the prompts request {"date": ..., "description": ...} dicts,
    # but nothing here validates that shape.
    deadlines: list[dict]
    # Lowercase topic tags.
    topic_tags: list[str]
    # Provider identifier, e.g. "openai", "anthropic", "gemini" or "ollama".
    llm_provider: str
    # Model name reported by the provider configuration.
    llm_model: str
|
|
|
|
|
|
def smart_truncate(text: str, max_tokens: int, *, head_fraction: float = 0.75) -> str:
    """Shrink bill text to an approximate token budget, keeping its start and end.

    Token usage is estimated as ``len(text) * TOKENS_PER_CHAR`` (about four
    characters per token) rather than with a real tokenizer. Text that already
    fits is returned unchanged. Otherwise ``head_fraction`` of the budget is
    spent on the beginning of the text and the remainder on its end, joined by
    a marker stating how many characters were dropped in between.

    Args:
        text: Full document text.
        max_tokens: Approximate token budget; negative values are treated as 0.
        head_fraction: Share of the budget given to the start of the text
            (0.0-1.0). The default 0.75 keeps the original 75/25 split.

    Returns:
        ``text`` unchanged, or ``head + marker + tail``. The marker itself is
        not counted against the budget.

    Raises:
        ValueError: If ``head_fraction`` is outside [0, 1].
    """
    if not 0.0 <= head_fraction <= 1.0:
        raise ValueError(f"head_fraction must be between 0 and 1, got {head_fraction!r}")
    if not text or len(text) * TOKENS_PER_CHAR <= max_tokens:
        return text

    budget = max(max_tokens, 0)
    # The head is meant to capture the bill's opening (title / purpose sections) and
    # the tail its closing provisions (e.g. effective dates, enforcement).
    head_chars = int(budget * head_fraction / TOKENS_PER_CHAR)
    tail_chars = int(budget * (1.0 - head_fraction) / TOKENS_PER_CHAR)
    omitted_chars = len(text) - head_chars - tail_chars

    # Slice the tail from an absolute index: text[-0:] is the WHOLE string, so a
    # negative-index slice would re-append the entire document when tail_chars == 0.
    tail = text[len(text) - tail_chars:]
    return (
        text[:head_chars]
        + f"\n\n[... {omitted_chars:,} characters omitted for length ...]\n\n"
        + tail
    )
|
|
|
|
|
|
# System prompt for amendment briefs: asks the model to describe only what changed
# between two versions of a bill, using the same JSON keys as full-bill briefs so the
# reply can be parsed by parse_brief_json into a ReverseBrief. This prompt is sent on
# its own, so it must spell out the standard topic tags itself rather than refer to
# another prompt.
AMENDMENT_SYSTEM_PROMPT = """You are a nonpartisan legislative analyst. A bill has been updated \
and you must summarize what changed between the previous and new version.

Always respond with valid JSON matching exactly this schema:
{
"summary": "2-3 paragraph plain-language description of what changed in this version",
"key_points": ["specific change 1", "specific change 2"],
"risks": ["new concern introduced by this change 1", "concern 2"],
"deadlines": [{"date": "YYYY-MM-DD or null", "description": "new deadline added"}],
"topic_tags": ["healthcare", "taxation"]
}

Rules:
- summary: Focus ONLY on what is different from the previous version. Be specific.
- key_points: List concrete additions, removals, or modifications in this version.
- risks: Only include risks that are new or changed relative to the previous version.
- deadlines: Only new or changed deadlines. Empty list if none.
- topic_tags: 3-8 lowercase tags for the bill as amended — include any new topics this version \
adds. Prefer these standard tags: healthcare, taxation, defense, education, immigration, \
environment, housing, infrastructure, technology, agriculture, judiciary, foreign-policy, \
veterans, social-security, trade, budget, energy, banking, transportation, public-lands, labor, \
civil-rights, science.

Respond with ONLY valid JSON. No preamble, no explanation, no markdown code blocks."""
|
|
|
|
|
|
def build_amendment_prompt(new_text: str, previous_text: str, bill_metadata: dict, max_tokens: int) -> str:
    """Build the user prompt asking the model what changed between two bill versions.

    The token budget is split evenly: each version is truncated with
    ``smart_truncate`` to ``max_tokens // 2`` approximate tokens.

    Args:
        new_text: Text of the updated bill version.
        previous_text: Text of the prior bill version.
        bill_metadata: Bill fields; missing keys fall back to placeholders
            ("Unknown", "?", "None").
        max_tokens: Combined approximate token budget for both versions.

    Returns:
        Prompt text: metadata, previous version, new version, and a closing
        instruction to emit the JSON amendment summary.
    """
    per_version = max_tokens // 2
    new_excerpt = smart_truncate(new_text, per_version)
    prev_excerpt = smart_truncate(previous_text, per_version)

    md = bill_metadata.get
    metadata_section = "\n".join([
        "BILL METADATA:",
        f"- Title: {md('title', 'Unknown')}",
        f"- Sponsor: {md('sponsor_name', 'Unknown')} ({md('party', '?')}-{md('state', '?')})",
        f"- Latest Action: {md('latest_action_text', 'None')} ({md('latest_action_date', 'Unknown')})",
    ])

    # Sections are separated by one blank line each.
    sections = [
        "A bill has been updated. Summarize what changed between the previous and new version.",
        metadata_section,
        f"PREVIOUS VERSION:\n{prev_excerpt}",
        f"NEW VERSION:\n{new_excerpt}",
        "Produce the JSON amendment summary now:",
    ]
    return "\n\n".join(sections)
|
|
|
|
|
|
def build_prompt(doc_text: str, bill_metadata: dict, max_tokens: int) -> str:
    """Build the user prompt asking the model for a structured brief of one bill.

    Args:
        doc_text: Full bill text; truncated with ``smart_truncate`` to ``max_tokens``.
        bill_metadata: Bill fields; missing keys fall back to placeholders
            ("Unknown", "?", "None").
        max_tokens: Approximate token budget for the bill text.

    Returns:
        Prompt text: metadata block, bill text, and a closing instruction to
        emit the JSON brief.
    """
    bill_body = smart_truncate(doc_text, max_tokens)
    meta = bill_metadata.get

    sponsor = f"{meta('sponsor_name', 'Unknown')} ({meta('party', '?')}-{meta('state', '?')})"
    latest_action = f"{meta('latest_action_text', 'None')} ({meta('latest_action_date', 'Unknown')})"

    lines = [
        "Analyze this legislation and produce a structured brief.",
        "",
        "BILL METADATA:",
        f"- Title: {meta('title', 'Unknown')}",
        f"- Sponsor: {sponsor}",
        f"- Introduced: {meta('introduced_date', 'Unknown')}",
        f"- Chamber: {meta('chamber', 'Unknown')}",
        f"- Latest Action: {latest_action}",
        "",
        "BILL TEXT:",
        bill_body,
        "",
        "Produce the JSON brief now:",
    ]
    return "\n".join(lines)
|
|
|
|
|
|
# Markdown code fences some models wrap around JSON despite being told not to.
_JSON_FENCE_OPEN = re.compile(r"^```(?:json)?\s*")
_JSON_FENCE_CLOSE = re.compile(r"\s*```$")


def _as_list(value) -> list:
    """Coerce an LLM-supplied array field into a list.

    ``None`` (JSON ``null``) becomes ``[]``; lists and tuples are copied; any
    other single value (bare string, object, number) is wrapped in a
    one-element list instead of being split apart by ``list()``.
    """
    if value is None:
        return []
    if isinstance(value, (list, tuple)):
        return list(value)
    return [value]


def parse_brief_json(raw: str | dict, provider: str, model: str) -> ReverseBrief:
    """Parse and validate an LLM JSON response into a ReverseBrief.

    Args:
        raw: The model's reply: raw text (optionally wrapped in markdown code
            fences or surrounded by stray prose) or an already-decoded dict.
        provider: Provider identifier stored on the brief (e.g. "openai").
        model: Model name stored on the brief.

    Returns:
        A ReverseBrief. Missing or null fields become "" / []; a single value
        where an array was expected is wrapped in a one-element list.

    Raises:
        json.JSONDecodeError: If no JSON can be decoded from ``raw``, or the
            decoded value is not a JSON object. The latter uses the same type so
            callers' existing ``except json.JSONDecodeError`` retry paths apply.
    """
    if isinstance(raw, str):
        text = _JSON_FENCE_OPEN.sub("", raw.strip())
        text = _JSON_FENCE_CLOSE.sub("", text.strip())
        try:
            data = json.loads(text)
        except json.JSONDecodeError:
            # Fall back to the outermost {...} span for replies like "Here is the JSON: {...}".
            start, end = text.find("{"), text.rfind("}")
            if start == -1 or end <= start:
                raise
            data = json.loads(text[start:end + 1])
    else:
        data = raw

    if not isinstance(data, dict):
        doc = raw if isinstance(raw, str) else repr(raw)
        raise json.JSONDecodeError(
            f"Expected a JSON object, got {type(data).__name__}", doc, 0
        )

    summary = data.get("summary")
    return ReverseBrief(
        # A JSON null summary should not become the literal string "None".
        summary="" if summary is None else str(summary),
        key_points=_as_list(data.get("key_points")),
        risks=_as_list(data.get("risks")),
        deadlines=_as_list(data.get("deadlines")),
        topic_tags=_as_list(data.get("topic_tags")),
        llm_provider=provider,
        llm_model=model,
    )
|
|
|
|
|
|
class LLMProvider(ABC):
    """Interface every LLM backend implements to turn bill text into a ReverseBrief.

    The concrete implementation is chosen by ``get_llm_provider`` from
    ``settings.LLM_PROVIDER``.
    """

    @abstractmethod
    def generate_brief(self, doc_text: str, bill_metadata: dict) -> ReverseBrief:
        """Produce a brief for a single bill version.

        Args:
            doc_text: Full bill text (implementations may truncate it).
            bill_metadata: Bill fields; the prompt builders read keys such as
                title, sponsor_name, party, state, introduced_date, chamber,
                latest_action_text and latest_action_date.
        """

    @abstractmethod
    def generate_amendment_brief(self, new_text: str, previous_text: str, bill_metadata: dict) -> ReverseBrief:
        """Produce a brief describing what changed between two bill versions.

        Args:
            new_text: Text of the updated version.
            previous_text: Text of the prior version.
            bill_metadata: Bill fields, as for ``generate_brief``.
        """
|
|
|
|
|
|
class OpenAIProvider(LLMProvider):
    """Generate briefs with the OpenAI Chat Completions API in JSON mode."""

    def __init__(self):
        # Imported lazily so the SDK is only needed when this provider is constructed.
        from openai import OpenAI

        self.client = OpenAI(api_key=settings.OPENAI_API_KEY)
        self.model = settings.OPENAI_MODEL

    def _complete_json(self, system_prompt: str, user_prompt: str) -> ReverseBrief:
        """Run one JSON-mode chat completion and parse the reply into a ReverseBrief.

        Raises:
            json.JSONDecodeError: If the reply is empty or not valid JSON
                (propagated from ``parse_brief_json``).
        """
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            response_format={"type": "json_object"},
            temperature=0.1,
        )
        choice = response.choices[0]
        if choice.finish_reason == "length":
            # A reply cut off at the output limit is almost certainly truncated JSON.
            logger.warning("OpenAI response hit the output token limit (model=%s)", self.model)
        # message.content is nullable (e.g. on a refusal); pass "" so the failure
        # surfaces as a JSON decode error rather than an AttributeError.
        raw = choice.message.content or ""
        return parse_brief_json(raw, "openai", self.model)

    def generate_brief(self, doc_text: str, bill_metadata: dict) -> ReverseBrief:
        """Summarize one bill version (bill text truncated to MAX_TOKENS_DEFAULT)."""
        prompt = build_prompt(doc_text, bill_metadata, MAX_TOKENS_DEFAULT)
        return self._complete_json(SYSTEM_PROMPT, prompt)

    def generate_amendment_brief(self, new_text: str, previous_text: str, bill_metadata: dict) -> ReverseBrief:
        """Summarize the differences between two bill versions."""
        prompt = build_amendment_prompt(new_text, previous_text, bill_metadata, MAX_TOKENS_DEFAULT)
        return self._complete_json(AMENDMENT_SYSTEM_PROMPT, prompt)
|
|
|
|
|
|
class AnthropicProvider(LLMProvider):
    """Generate briefs with the Anthropic Messages API.

    This call does not request a structured-output mode, so each system prompt
    gets an extra JSON-only reminder and ``parse_brief_json`` strips code fences.
    """

    # Appended to every system prompt as an extra nudge toward bare JSON output.
    _JSON_ONLY_SUFFIX = "\n\nIMPORTANT: Respond with ONLY valid JSON. No other text."
    # Cap on generated tokens per request (max_tokens is a required Messages API parameter).
    MAX_OUTPUT_TOKENS = 4096

    def __init__(self):
        # Imported lazily so the SDK is only needed when this provider is constructed.
        import anthropic

        self.client = anthropic.Anthropic(api_key=settings.ANTHROPIC_API_KEY)
        self.model = settings.ANTHROPIC_MODEL

    def _complete_json(self, system_prompt: str, user_prompt: str) -> ReverseBrief:
        """Send one message exchange and parse the reply into a ReverseBrief.

        Raises:
            json.JSONDecodeError: If the reply contains no decodable JSON
                (propagated from ``parse_brief_json``).
        """
        response = self.client.messages.create(
            model=self.model,
            max_tokens=self.MAX_OUTPUT_TOKENS,
            system=system_prompt + self._JSON_ONLY_SUFFIX,
            messages=[{"role": "user", "content": user_prompt}],
            # Low temperature, matching the OpenAI and Gemini providers, for stable output.
            temperature=0.1,
        )
        if response.stop_reason == "max_tokens":
            # A reply cut off at the output limit is almost certainly truncated JSON.
            logger.warning("Anthropic response hit max_tokens=%d (model=%s)", self.MAX_OUTPUT_TOKENS, self.model)
        # Join every text block instead of assuming content[0] exists and is text.
        raw = "".join(
            block.text for block in response.content if getattr(block, "type", None) == "text"
        )
        return parse_brief_json(raw, "anthropic", self.model)

    def generate_brief(self, doc_text: str, bill_metadata: dict) -> ReverseBrief:
        """Summarize one bill version (bill text truncated to MAX_TOKENS_DEFAULT)."""
        prompt = build_prompt(doc_text, bill_metadata, MAX_TOKENS_DEFAULT)
        return self._complete_json(SYSTEM_PROMPT, prompt)

    def generate_amendment_brief(self, new_text: str, previous_text: str, bill_metadata: dict) -> ReverseBrief:
        """Summarize the differences between two bill versions."""
        prompt = build_amendment_prompt(new_text, previous_text, bill_metadata, MAX_TOKENS_DEFAULT)
        return self._complete_json(AMENDMENT_SYSTEM_PROMPT, prompt)
|
|
|
|
|
|
class GeminiProvider(LLMProvider):
    """Generate briefs with the ``google.generativeai`` SDK using a JSON response MIME type."""

    def __init__(self):
        # SDK imported lazily; the module object is kept to build models per request.
        import google.generativeai as genai

        genai.configure(api_key=settings.GEMINI_API_KEY)
        self._genai = genai
        self.model_name = settings.GEMINI_MODEL

    def _make_model(self, system_prompt: str):
        """Create a GenerativeModel bound to *system_prompt*, JSON output and temperature 0.1."""
        config = {"response_mime_type": "application/json", "temperature": 0.1}
        return self._genai.GenerativeModel(
            model_name=self.model_name,
            generation_config=config,
            system_instruction=system_prompt,
        )

    def _run(self, system_prompt: str, user_prompt: str) -> ReverseBrief:
        """Generate content for *user_prompt* and parse the reply text into a ReverseBrief."""
        reply = self._make_model(system_prompt).generate_content(user_prompt)
        return parse_brief_json(reply.text, "gemini", self.model_name)

    def generate_brief(self, doc_text: str, bill_metadata: dict) -> ReverseBrief:
        """Summarize one bill version (bill text truncated to MAX_TOKENS_DEFAULT)."""
        return self._run(SYSTEM_PROMPT, build_prompt(doc_text, bill_metadata, MAX_TOKENS_DEFAULT))

    def generate_amendment_brief(self, new_text: str, previous_text: str, bill_metadata: dict) -> ReverseBrief:
        """Summarize the differences between two bill versions."""
        user_prompt = build_amendment_prompt(new_text, previous_text, bill_metadata, MAX_TOKENS_DEFAULT)
        return self._run(AMENDMENT_SYSTEM_PROMPT, user_prompt)
|
|
|
|
|
|
class OllamaProvider(LLMProvider):
    """Generate briefs with a local Ollama server via its /api/generate endpoint."""

    # Appended to the user prompt on the single retry after an unparsable reply.
    _RETRY_SUFFIX = (
        "\n\nCRITICAL: Your response MUST be valid JSON only. "
        "No text before or after the JSON object."
    )
    # Seconds to wait for one non-streaming generation request.
    _REQUEST_TIMEOUT = 300

    def __init__(self):
        self.base_url = settings.OLLAMA_BASE_URL.rstrip("/")
        self.model = settings.OLLAMA_MODEL

    def _generate(self, system_prompt: str, user_prompt: str) -> str:
        """POST one non-streaming, JSON-format request and return the raw ``response`` text.

        Retrying on unparsable output is done by ``_brief_with_retry``, where
        parsing actually happens; this method only performs the HTTP call.

        Raises:
            requests.HTTPError: On a non-2xx status (via ``raise_for_status``).
        """
        import requests as req

        response = req.post(
            f"{self.base_url}/api/generate",
            json={
                "model": self.model,
                # System and user prompts are sent as one concatenated prompt.
                "prompt": f"{system_prompt}\n\n{user_prompt}",
                "stream": False,
                "format": "json",
            },
            timeout=self._REQUEST_TIMEOUT,
        )
        response.raise_for_status()
        return response.json().get("response", "")

    def _brief_with_retry(self, system_prompt: str, user_prompt: str, failure_msg: str) -> ReverseBrief:
        """Generate and parse a brief, retrying once with a stricter prompt if parsing fails.

        Raises:
            Whatever ``parse_brief_json`` raises on the second attempt, or
            ``requests`` errors from either HTTP call.
        """
        raw = self._generate(system_prompt, user_prompt)
        try:
            return parse_brief_json(raw, "ollama", self.model)
        except (ValueError, TypeError, AttributeError, KeyError) as e:
            # ValueError covers json.JSONDecodeError; TypeError/AttributeError cover
            # JSON that decodes but has the wrong shape (e.g. a list, or null arrays).
            logger.warning("%s, retrying: %s", failure_msg, e)
        raw = self._generate(system_prompt, user_prompt + self._RETRY_SUFFIX)
        return parse_brief_json(raw, "ollama", self.model)

    def generate_brief(self, doc_text: str, bill_metadata: dict) -> ReverseBrief:
        """Summarize one bill version (bill text truncated to MAX_TOKENS_OLLAMA)."""
        prompt = build_prompt(doc_text, bill_metadata, MAX_TOKENS_OLLAMA)
        return self._brief_with_retry(SYSTEM_PROMPT, prompt, "Ollama JSON parse failed")

    def generate_amendment_brief(self, new_text: str, previous_text: str, bill_metadata: dict) -> ReverseBrief:
        """Summarize the differences between two bill versions (MAX_TOKENS_OLLAMA budget)."""
        prompt = build_amendment_prompt(new_text, previous_text, bill_metadata, MAX_TOKENS_OLLAMA)
        return self._brief_with_retry(
            AMENDMENT_SYSTEM_PROMPT, prompt, "Ollama amendment JSON parse failed"
        )
|
|
|
|
|
|
def get_llm_provider() -> LLMProvider:
    """Factory — return a new instance of the provider named by ``settings.LLM_PROVIDER``.

    The setting is matched case-insensitively and ignoring surrounding
    whitespace. Only the selected provider is constructed, so only its SDK
    needs to be installed.

    Raises:
        ValueError: If the setting is empty or names an unknown provider.
    """
    provider = (settings.LLM_PROVIDER or "").strip().lower()
    factories = {
        "openai": OpenAIProvider,
        "anthropic": AnthropicProvider,
        "gemini": GeminiProvider,
        "ollama": OllamaProvider,
    }
    factory = factories.get(provider)
    if factory is None:
        raise ValueError(f"Unknown LLM_PROVIDER: '{provider}'. Must be one of: {', '.join(factories)}")
    return factory()
|