fix(admin): LLM provider/model switching now reads DB overrides correctly
- get_llm_provider() now accepts provider + model args so DB overrides propagate through to all provider constructors (was always reading env vars, ignoring the admin UI selection) - /test-llm replaced with lightweight ping (max_tokens=20) instead of running a full bill analysis; shows model name + reply, no truncation - /api/settings/llm-models endpoint fetches available models live from each provider's API (OpenAI, Anthropic REST, Gemini, Ollama) - Admin UI model picker dynamically populated from provider API; falls back to manual text input on error; Custom model name option kept - Default Gemini model updated: gemini-1.5-pro → gemini-2.0-flash Co-Authored-By: Jack Levy
This commit is contained in:
@@ -55,32 +55,156 @@ async def update_setting(
|
||||
|
||||
|
||||
@router.post("/test-llm")
|
||||
async def test_llm_connection(current_user: User = Depends(get_current_admin)):
|
||||
"""Test that the configured LLM provider responds correctly."""
|
||||
from app.services.llm_service import get_llm_provider
|
||||
async def test_llm_connection(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
current_user: User = Depends(get_current_admin),
|
||||
):
|
||||
"""Ping the configured LLM provider with a minimal request."""
|
||||
import asyncio
|
||||
prov_row = await db.get(AppSetting, "llm_provider")
|
||||
model_row = await db.get(AppSetting, "llm_model")
|
||||
provider_name = prov_row.value if prov_row else settings.LLM_PROVIDER
|
||||
model_name = model_row.value if model_row else None
|
||||
try:
|
||||
provider = get_llm_provider()
|
||||
brief = provider.generate_brief(
|
||||
doc_text="This is a test bill for connection verification purposes.",
|
||||
bill_metadata={
|
||||
"title": "Test Connection Bill",
|
||||
"sponsor_name": "Test Sponsor",
|
||||
"party": "Test",
|
||||
"state": "DC",
|
||||
"chamber": "House",
|
||||
"introduced_date": "2025-01-01",
|
||||
"latest_action_text": "Test action",
|
||||
"latest_action_date": "2025-01-01",
|
||||
},
|
||||
return await asyncio.to_thread(_ping_provider, provider_name, model_name)
|
||||
except Exception as exc:
|
||||
return {"status": "error", "detail": str(exc)}
|
||||
|
||||
|
||||
_PING = "Reply with exactly three words: Connection test successful."
|
||||
|
||||
|
||||
def _ping_provider(provider_name: str, model_name: str | None) -> dict:
|
||||
if provider_name == "openai":
|
||||
from openai import OpenAI
|
||||
model = model_name or settings.OPENAI_MODEL
|
||||
client = OpenAI(api_key=settings.OPENAI_API_KEY)
|
||||
resp = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=[{"role": "user", "content": _PING}],
|
||||
max_tokens=20,
|
||||
)
|
||||
return {
|
||||
"status": "ok",
|
||||
"provider": brief.llm_provider,
|
||||
"model": brief.llm_model,
|
||||
"summary_preview": brief.summary[:100] + "..." if len(brief.summary) > 100 else brief.summary,
|
||||
}
|
||||
except Exception as e:
|
||||
return {"status": "error", "detail": str(e)}
|
||||
reply = resp.choices[0].message.content.strip()
|
||||
return {"status": "ok", "provider": "openai", "model": model, "reply": reply}
|
||||
|
||||
if provider_name == "anthropic":
|
||||
import anthropic
|
||||
model = model_name or settings.ANTHROPIC_MODEL
|
||||
client = anthropic.Anthropic(api_key=settings.ANTHROPIC_API_KEY)
|
||||
resp = client.messages.create(
|
||||
model=model,
|
||||
max_tokens=20,
|
||||
messages=[{"role": "user", "content": _PING}],
|
||||
)
|
||||
reply = resp.content[0].text.strip()
|
||||
return {"status": "ok", "provider": "anthropic", "model": model, "reply": reply}
|
||||
|
||||
if provider_name == "gemini":
|
||||
import google.generativeai as genai
|
||||
model = model_name or settings.GEMINI_MODEL
|
||||
genai.configure(api_key=settings.GEMINI_API_KEY)
|
||||
resp = genai.GenerativeModel(model_name=model).generate_content(_PING)
|
||||
reply = resp.text.strip()
|
||||
return {"status": "ok", "provider": "gemini", "model": model, "reply": reply}
|
||||
|
||||
if provider_name == "ollama":
|
||||
import requests as req
|
||||
model = model_name or settings.OLLAMA_MODEL
|
||||
resp = req.post(
|
||||
f"{settings.OLLAMA_BASE_URL}/api/generate",
|
||||
json={"model": model, "prompt": _PING, "stream": False},
|
||||
timeout=30,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
reply = resp.json().get("response", "").strip()
|
||||
return {"status": "ok", "provider": "ollama", "model": model, "reply": reply}
|
||||
|
||||
raise ValueError(f"Unknown provider: {provider_name}")
|
||||
|
||||
|
||||
@router.get("/llm-models")
|
||||
async def list_llm_models(
|
||||
provider: str,
|
||||
current_user: User = Depends(get_current_admin),
|
||||
):
|
||||
"""Fetch available models directly from the provider's API."""
|
||||
import asyncio
|
||||
handlers = {
|
||||
"openai": _list_openai_models,
|
||||
"anthropic": _list_anthropic_models,
|
||||
"gemini": _list_gemini_models,
|
||||
"ollama": _list_ollama_models,
|
||||
}
|
||||
fn = handlers.get(provider)
|
||||
if not fn:
|
||||
return {"models": [], "error": f"Unknown provider: {provider}"}
|
||||
try:
|
||||
return await asyncio.to_thread(fn)
|
||||
except Exception as exc:
|
||||
return {"models": [], "error": str(exc)}
|
||||
|
||||
|
||||
def _list_openai_models() -> dict:
|
||||
from openai import OpenAI
|
||||
if not settings.OPENAI_API_KEY:
|
||||
return {"models": [], "error": "OPENAI_API_KEY not configured"}
|
||||
client = OpenAI(api_key=settings.OPENAI_API_KEY)
|
||||
all_models = client.models.list().data
|
||||
CHAT_PREFIXES = ("gpt-", "o1", "o3", "o4", "chatgpt-")
|
||||
EXCLUDE = ("realtime", "audio", "tts", "whisper", "embedding", "dall-e", "instruct")
|
||||
filtered = sorted(
|
||||
[m.id for m in all_models
|
||||
if any(m.id.startswith(p) for p in CHAT_PREFIXES)
|
||||
and not any(x in m.id for x in EXCLUDE)],
|
||||
reverse=True,
|
||||
)
|
||||
return {"models": [{"id": m, "name": m} for m in filtered]}
|
||||
|
||||
|
||||
def _list_anthropic_models() -> dict:
|
||||
import requests as req
|
||||
if not settings.ANTHROPIC_API_KEY:
|
||||
return {"models": [], "error": "ANTHROPIC_API_KEY not configured"}
|
||||
resp = req.get(
|
||||
"https://api.anthropic.com/v1/models",
|
||||
headers={
|
||||
"x-api-key": settings.ANTHROPIC_API_KEY,
|
||||
"anthropic-version": "2023-06-01",
|
||||
},
|
||||
timeout=10,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
return {
|
||||
"models": [
|
||||
{"id": m["id"], "name": m.get("display_name", m["id"])}
|
||||
for m in data.get("data", [])
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def _list_gemini_models() -> dict:
|
||||
import google.generativeai as genai
|
||||
if not settings.GEMINI_API_KEY:
|
||||
return {"models": [], "error": "GEMINI_API_KEY not configured"}
|
||||
genai.configure(api_key=settings.GEMINI_API_KEY)
|
||||
models = [
|
||||
{"id": m.name.replace("models/", ""), "name": m.display_name}
|
||||
for m in genai.list_models()
|
||||
if "generateContent" in m.supported_generation_methods
|
||||
]
|
||||
return {"models": sorted(models, key=lambda x: x["id"])}
|
||||
|
||||
|
||||
def _list_ollama_models() -> dict:
|
||||
import requests as req
|
||||
try:
|
||||
resp = req.get(f"{settings.OLLAMA_BASE_URL}/api/tags", timeout=5)
|
||||
resp.raise_for_status()
|
||||
tags = resp.json().get("models", [])
|
||||
return {"models": [{"id": m["name"], "name": m["name"]} for m in tags]}
|
||||
except Exception as exc:
|
||||
return {"models": [], "error": f"Ollama unreachable: {exc}"}
|
||||
|
||||
|
||||
def _current_model(provider: str) -> str:
|
||||
|
||||
Reference in New Issue
Block a user