utils.py 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. #!/usr/bin/env python3
  2. """Shared helpers for the RobotDaily arXiv digest skill."""
  3. from __future__ import annotations
  4. import json
  5. import os
  6. import re
  7. import subprocess
  8. import sys
  9. from datetime import datetime
  10. from pathlib import Path
  11. from typing import Any, Dict, Iterable, List, Optional
  12. from urllib.error import HTTPError, URLError
  13. from urllib.request import Request, urlopen
  14. from zoneinfo import ZoneInfo
  # SKILL_DIR is the resolved file's grandparent (parents[1]): the folder that
  # contains the folder holding this module. ROOT_DIR is one level above that.
  SKILL_DIR = Path(__file__).resolve().parents[1]
  ROOT_DIR = SKILL_DIR.parent
  # Default artifact locations, both inside SKILL_DIR.
  DEFAULT_OUTPUT_DIR = SKILL_DIR.joinpath("output")
  DEFAULT_LOG_DIR = SKILL_DIR.joinpath("logs")
  # Fixed display time zone used for timestamps in this module.
  LOCAL_TZ = ZoneInfo("Asia/Shanghai")
  # Ollama generate endpoint; the OLLAMA_URL environment variable overrides the default.
  OLLAMA_URL = os.getenv("OLLAMA_URL", "http://127.0.0.1:11434/api/generate")
  def log(message: str) -> None:
      """Write a diagnostic line to stderr, prefixed with the LOCAL_TZ wall-clock time.

      The line looks like ``[14:03:07] message``; nothing is written to stdout.
      """
      current = datetime.now(LOCAL_TZ)
      timestamp = format(current, "%H:%M:%S")
      print(f"[{timestamp}] {message}", file=sys.stderr)
  def now_local() -> datetime:
      """Return the current moment as a timezone-aware datetime in LOCAL_TZ."""
      return datetime.now(tz=LOCAL_TZ)
  def ensure_dir(path: Path | str) -> Path:
      """Create the directory *path* (with any missing parents) and return it as a Path.

      An already-existing directory is not an error.
      """
      directory = Path(path)
      directory.mkdir(exist_ok=True, parents=True)
      return directory
  def normalize_space(text: str) -> str:
      """Collapse every whitespace run to a single space and trim both ends.

      Falsy input (``None``, ``""``, ``0``) yields ``""``; anything else is
      passed through ``str()`` first.
      """
      # str.split() with no separator splits on the same Unicode whitespace set
      # that the regex ``\s`` matches, and drops leading/trailing runs.
      return " ".join(str(text or "").split())
  def slugify(text: str) -> str:
      """Build a lowercase, hyphen-separated ASCII slug from *text*.

      Every run of characters outside ``[a-zA-Z0-9]`` becomes one ``-`` and
      edge hyphens are removed. Returns ``"digest"`` when nothing survives
      (empty input, or text with no ASCII letters/digits).
      """
      lowered = str(text or "").strip().lower()
      hyphenated = re.sub(r"[^a-zA-Z0-9]+", "-", lowered)
      return hyphenated.strip("-") or "digest"
  # New-style arXiv identifiers: YYMM.NNNN or YYMM.NNNNN, optionally followed by vN.
  _ARXIV_NEW_ID_RE = re.compile(r"(?<!\d)(\d{4}\.\d{4,5})(?:v\d+)?(?!\d)")
  # Old-style identifiers: archive[.SC]/YYMMNNN, e.g. hep-th/9901001 or math.GT/0309136.
  _ARXIV_OLD_ID_RE = re.compile(
      r"(?<![\w.-])([a-z]+(?:-[a-z]+)?(?:\.[A-Z]{2})?/\d{7})(?:v\d+)?(?!\d)"
  )


  def canonical_arxiv_id(raw: str) -> str:
      """Reduce an arXiv ID, abs/pdf URL or arXiv DOI to the bare, version-less ID.

      Examples:
          "arXiv:2401.12345v2"                      -> "2401.12345"
          "https://arxiv.org/pdf/2401.12345v2.pdf"  -> "2401.12345"
          "10.48550/arXiv.2401.12345"               -> "2401.12345"
          "http://arxiv.org/abs/hep-th/9901001v1"   -> "hep-th/9901001"

      Input matching neither ID pattern falls back to the previous heuristic:
      last "/"-separated segment, "arXiv:" removed, trailing "vN" stripped.
      Empty/whitespace-only input returns "".
      """
      text = normalize_space(raw)
      if not text:
          return ""
      # Pattern extraction first: the old rsplit("/") heuristic dropped the
      # archive of old-style IDs and left "vN.pdf" on PDF URLs.
      for pattern in (_ARXIV_NEW_ID_RE, _ARXIV_OLD_ID_RE):
          match = pattern.search(text)
          if match:
              return match.group(1)
      text = text.rsplit("/", 1)[-1]
      text = text.replace("arXiv:", "")
      return re.sub(r"v\d+$", "", text)
  # Leading resolver URL (doi.org / dx.doi.org / www.doi.org) and/or "doi:" label, any case.
  _DOI_PREFIX_RE = re.compile(
      r"^(?:https?://(?:dx\.|www\.)?doi\.org/)?(?:doi:)?\s*", re.IGNORECASE
  )


  def canonical_doi(arxiv_id: str, doi: str = "") -> str:
      """Return a bare DOI such as ``10.1234/abc`` for a paper.

      An explicit *doi* wins: its whitespace is normalized and a leading
      ``http(s)://[dx.|www.]doi.org/`` and/or ``doi:`` prefix is removed
      (case-insensitively). Only the prefix is stripped, so a DOI suffix that
      happens to contain "doi:" is preserved. If *doi* is empty or only a
      prefix, the DOI ``10.48550/arXiv.<id>`` is derived from *arxiv_id*.
      Returns "" when neither source yields anything.
      """
      clean = _DOI_PREFIX_RE.sub("", normalize_space(doi), count=1).strip()
      if clean:
          return clean
      arxiv_clean = canonical_arxiv_id(arxiv_id)
      return f"10.48550/arXiv.{arxiv_clean}" if arxiv_clean else ""
  def canonical_doi_url(arxiv_id: str, doi: str = "") -> str:
      """Return the ``https://doi.org/<doi>`` resolver link for a paper.

      The DOI comes from ``canonical_doi(arxiv_id, doi)``; if that is empty,
      the result is "" as well.
      """
      clean_doi = canonical_doi(arxiv_id, doi)
      if not clean_doi:
          return ""
      return f"https://doi.org/{clean_doi}"
  def build_arxiv_urls(arxiv_id: str) -> Dict[str, str]:
      """Return the arXiv abstract-page and PDF links for *arxiv_id*.

      The ID is canonicalized first; both ``abs_url`` and ``pdf_url`` are ""
      when no ID remains.
      """
      clean = canonical_arxiv_id(arxiv_id)
      abs_url = f"https://arxiv.org/abs/{clean}" if clean else ""
      pdf_url = f"https://arxiv.org/pdf/{clean}.pdf" if clean else ""
      return {"abs_url": abs_url, "pdf_url": pdf_url}
  def read_json(path: Path | str, default: Any = None) -> Any:
      """Load and return the JSON document stored at *path* (UTF-8).

      Returns *default* when the path does not exist. Malformed JSON is not
      caught: ``json.JSONDecodeError`` propagates to the caller.
      """
      source = Path(path)
      if source.exists():
          return json.loads(source.read_text(encoding="utf-8"))
      return default
  def write_json(path: Path | str, data: Any) -> Path:
      """Write *data* to *path* as pretty-printed (indent=2), non-ASCII-preserving UTF-8 JSON.

      Parent directories are created as needed. Returns the target as a Path.

      Raises:
          TypeError / ValueError: if *data* is not JSON-serializable. This is
          raised before the file is opened, so an existing file is left intact.
      """
      # Serialize first: opening with "w" truncates immediately, so a failure
      # halfway through json.dump used to leave a truncated/partial file.
      payload = json.dumps(data, ensure_ascii=False, indent=2)
      path_obj = Path(path)
      path_obj.parent.mkdir(parents=True, exist_ok=True)
      with path_obj.open("w", encoding="utf-8") as handle:
          handle.write(payload)
      return path_obj
  def write_text(path: Path | str, content: str) -> Path:
      """Write *content* to *path* as UTF-8 text, creating parent folders first.

      Any existing file is overwritten. Returns the target as a Path.
      """
      target = Path(path)
      target.parent.mkdir(parents=True, exist_ok=True)
      target.write_text(content, encoding="utf-8")
      return target
  def load_env(env_file: Path | str | None = None) -> Dict[str, str]:
      """Return a copy of os.environ extended with KEY=VALUE pairs from a dotenv file.

      Args:
          env_file: file to read; defaults to ``SKILL_DIR / ".env"``. A missing
              file is silently ignored.

      File format: blank lines and lines starting with "#" are skipped, as are
      lines without "=" or with an empty key. An optional leading ``export``
      keyword is accepted. One pair of matching surrounding quotes (" or ') is
      removed from the value. Variables already set in the real environment
      take precedence over the file.
      """
      env = dict(os.environ)
      env_path = Path(env_file) if env_file else SKILL_DIR / ".env"
      if not env_path.exists():
          return env
      for raw_line in env_path.read_text(encoding="utf-8").splitlines():
          line = raw_line.strip()
          if not line or line.startswith("#") or "=" not in line:
              continue
          key, value = line.split("=", 1)
          key = key.strip()
          # Shell-style "export KEY=value" lines previously produced the key "export KEY".
          parts = key.split()
          if len(parts) == 2 and parts[0] == "export":
              key = parts[1]
          if not key:
              continue
          value = value.strip()
          # Strip exactly one matching quote pair; the old .strip('"').strip("'")
          # also removed unmatched or literal quote characters (abc' -> abc).
          if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
              value = value[1:-1]
          env.setdefault(key, value)
      return env
  def extract_json_object(text: str) -> Optional[Dict[str, Any]]:
      """Return the first JSON object embedded in *text*, or None.

      Tolerates chatter around the JSON (as LLM replies often contain): each
      "{" is tried in order with ``json.JSONDecoder.raw_decode``, which parses
      one complete object and ignores whatever follows it. Returns None for
      empty input or when no position yields a valid object.
      """
      if not text:
          return None
      decoder = json.JSONDecoder()
      start = text.find("{")
      while start != -1:
          # The old greedy r"\{.*\}" spanned first-"{" to last-"}", so any stray
          # brace in surrounding prose (or a second object) broke parsing.
          # Worst case is quadratic in len(text); inputs here are short replies.
          try:
              payload, _end = decoder.raw_decode(text[start:])
          except (ValueError, RecursionError):
              payload = None
          if isinstance(payload, dict):
              return payload
          start = text.find("{", start + 1)
      return None
  def ollama_generate_json(prompt: str, model: str, timeout: int = 120) -> Optional[Dict[str, Any]]:
      """Send *prompt* to Ollama's generate endpoint and parse a JSON object from the reply.

      A single non-streaming POST goes to ``OLLAMA_URL`` with ``format="json"``,
      ``think=False``, temperature 0.1 and at most 800 predicted tokens. The
      reply's ``response`` field is handed to ``extract_json_object``.

      Returns the parsed dict, or None on any failure. Failures (HTTP error,
      connection error/timeout, unparseable reply) are logged via ``log``
      rather than raised.
      """
      options = {"temperature": 0.1, "num_predict": 800}
      request_payload = {
          "model": model,
          "prompt": prompt,
          "stream": False,
          "format": "json",
          "think": False,
          "options": options,
      }
      request = Request(
          OLLAMA_URL,
          data=json.dumps(request_payload).encode("utf-8"),
          headers={"Content-Type": "application/json"},
          method="POST",
      )
      try:
          with urlopen(request, timeout=timeout) as response:
              raw = response.read()
          payload = json.loads(raw.decode("utf-8", errors="ignore"))
          return extract_json_object(payload.get("response", ""))
      except HTTPError as exc:
          # HTTPError subclasses URLError, so it must be handled first; its body
          # often carries Ollama's error message.
          try:
              detail = exc.read().decode("utf-8", errors="ignore")
          except Exception:
              detail = ""
          log(f"Ollama request failed: {exc} {detail}".strip())
          return None
      except (URLError, TimeoutError) as exc:
          log(f"Ollama request failed: {exc}")
          return None
      except Exception as exc:
          log(f"Ollama parse failed: {exc}")
          return None
  class CommandError(RuntimeError):
      """Raised by the run_command helpers when a subprocess exits non-zero or its output is unusable."""
  def run_command(args: List[str], cwd: Path | str | None = None) -> subprocess.CompletedProcess[str]:
      """Run *args* (argument list, no shell) and return the completed process.

      stdout and stderr are captured as text. On a non-zero exit status a
      ``CommandError`` is raised. Its message is the stripped stderr, or the
      stripped stdout if stderr is empty, or ``"exit code N"`` if both are empty.
      """
      workdir = None if not cwd else str(cwd)
      result = subprocess.run(args, cwd=workdir, capture_output=True, text=True, check=False)
      if result.returncode == 0:
          return result
      detail = result.stderr.strip() or result.stdout.strip() or f"exit code {result.returncode}"
      raise CommandError(detail)
  def run_command_json(args: List[str], cwd: Path | str | None = None) -> Dict[str, Any]:
      """Run a command via ``run_command`` and parse its stdout as JSON.

      Empty stdout yields ``{}``. If the whole output is not valid JSON, the
      span from the first "{" to the last "}" is parsed instead, which
      tolerates log noise around a JSON object. Whatever ``json.loads``
      returns is passed through unchanged, so a non-object document (a list,
      a number) comes back as-is.

      Raises:
          CommandError: the command failed, or no parseable JSON was found
              (the message includes the first 300 characters of stdout).
      """
      stdout = run_command(args, cwd=cwd).stdout.strip()
      if not stdout:
          return {}
      try:
          return json.loads(stdout)
      except json.JSONDecodeError:
          first, last = stdout.find("{"), stdout.rfind("}")
          if first == -1 or last == -1 or last <= first:
              raise CommandError(f"Invalid JSON output: {stdout[:300]}")
          try:
              return json.loads(stdout[first : last + 1])
          except json.JSONDecodeError as exc:
              raise CommandError(f"Invalid JSON output: {exc}: {stdout[:300]}") from exc
  def _split_long_line(line: str, limit: int) -> List[str]:
      """Cut *line* into consecutive pieces of at most *limit* characters (no-op if limit <= 0)."""
      if limit <= 0 or len(line) <= limit:
          return [line]
      return [line[start : start + limit] for start in range(0, len(line), limit)]


  def chunk_lines(lines: Iterable[str], limit: int = 1800) -> List[str]:
      """Pack *lines* into newline-joined chunks of at most *limit* characters.

      Lines are kept in order and whole whenever they fit. A new chunk starts
      when appending the next line plus its joining newline would exceed
      *limit*. A single line longer than *limit* is cut into *limit*-sized
      pieces, so no chunk is oversized (previously such a line produced one
      chunk larger than the limit). With ``limit <= 0`` nothing is cut and
      every line becomes its own chunk. Non-string items are converted with
      ``str()``.
      """
      chunks: List[str] = []
      current: List[str] = []
      current_len = 0
      for line in lines:
          for piece in _split_long_line(str(line), limit):
              # +1 accounts for the "\n" that joins this piece to the previous one.
              extra = len(piece) + (1 if current else 0)
              if current and current_len + extra > limit:
                  chunks.append("\n".join(current))
                  current = [piece]
                  current_len = len(piece)
              else:
                  current.append(piece)
                  current_len += extra
      if current:
          chunks.append("\n".join(current))
      return chunks
  def format_authors(authors: List[str], limit: int = 4) -> str:
      """Render an author list as "A, B, C" for display.

      Names are whitespace-normalized and blank entries dropped. If more than
      *limit* names remain, only the first *limit* are shown, followed by the
      suffix "等另外N人" ("and N others") counting the hidden names.
      """
      items = [name for name in map(normalize_space, authors) if name]
      hidden = len(items) - limit
      if hidden <= 0:
          return ", ".join(items)
      return f"{', '.join(items[:limit])} 等另外{hidden}人"
  def truncate(text: str, limit: int) -> str:
      """Whitespace-normalize *text* and cap it at *limit* characters.

      Text that fits is returned unchanged (after normalization). Longer text
      is cut, trailing spaces are trimmed, and "…" is appended; the ellipsis
      counts toward *limit*. A non-positive *limit* yields "" for non-empty text.
      """
      clean = normalize_space(text)
      if len(clean) <= limit:
          return clean
      if limit <= 0:
          # clean[: limit - 1] with limit <= 0 slices from the END, so the old
          # code returned e.g. "hell…" for truncate("hello", 0).
          return ""
      return clean[: limit - 1].rstrip() + "…"
  def html_escape(text: str) -> str:
      """Escape ``&``, ``<``, ``>``, ``"`` and ``'`` for HTML text and quoted attributes.

      Falsy input (``None``, ``""``) becomes ``""``. The apostrophe maps to
      ``&#x27;``, the same entity ``html.escape(..., quote=True)`` emits. Without
      it, a value placed in a single-quoted attribute could break out of it.
      """
      return (
          str(text or "")
          .replace("&", "&amp;")  # must run first so later entities are not double-escaped
          .replace("<", "&lt;")
          .replace(">", "&gt;")
          .replace('"', "&quot;")
          .replace("'", "&#x27;")
      )