enrich_papers.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. #!/usr/bin/env python3
  2. """Translate abstracts, generate tags, and produce short explanations."""
  3. from __future__ import annotations
  4. import argparse
  5. import json
  6. from typing import Any, Dict, List
  7. from fetch_arxiv import DOMAIN_CONFIGS
  8. from utils import log, normalize_space, ollama_generate_json, read_json, truncate, write_json
  9. FALLBACK_TAGS = {
  10. "embodied": ["具身智能", "机器人", "真实部署", "操控", "导航"],
  11. "representation": ["表征学习", "潜在空间", "世界模型", "预训练", "对象中心"],
  12. "reinforcement": ["强化学习", "策略优化", "奖励设计", "离线 RL", "模仿学习"],
  13. }
  14. def build_prompt(paper: Dict[str, Any]) -> str:
  15. domain_label = DOMAIN_CONFIGS[paper["domain"]]["label_zh"]
  16. return f"""
  17. 你是 RobotDaily 的论文晨报编辑。请根据给定的英文标题与英文摘要,输出严格 JSON。
  18. 只输出一个 JSON 对象,结构如下:
  19. {{
  20. "translated_abstract_zh": "...",
  21. "brief_explanation_zh": "...",
  22. "tags": ["标签 1", "标签 2", "标签 3", "标签 4", "标签 5"]
  23. }}
  24. 要求:
  25. 1. translated_abstract_zh:忠实翻译原摘要,不要增加原文没有的实验结果;控制在 180-400 个中文字符,必须完整覆盖原文摘要的所有要点。
  26. 2. brief_explanation_zh:40-90 个中文字符,说明为什么值得读,尽量偏应用价值和创新点。
  27. 3. tags:给 4-6 个适合直接贴在移动端卡片上的简短标签;尽量用中文,必要时保留通用英文术语,如 World Model、Offline RL。
  28. 4. 语气务实、技术导向,不要夸张。
  29. 5. 不要输出 Markdown,不要输出代码块。
  30. 领域:{domain_label}
  31. 标题:{paper['title']}
  32. 英文摘要:{paper['summary']}
  33. """.strip()
  34. def fallback_enrichment(paper: Dict[str, Any]) -> Dict[str, Any]:
  35. tags = FALLBACK_TAGS.get(paper["domain"], ["AI 论文", "机器学习", "应用研究"])
  36. title = paper.get("title", "")
  37. summary = paper.get("summary", "")
  38. # 从标题提取核心方法
  39. brief = title.split(':')[0].strip()
  40. if '.' in brief:
  41. brief = brief.split('.')[0].strip()
  42. # 判断方法类型
  43. summary_lower = summary.lower()
  44. if "diffusion" in summary_lower:
  45. method = "扩散模型"
  46. elif "reinforcement learning" in summary_lower:
  47. method = "强化学习"
  48. elif "imitation learning" in summary_lower:
  49. method = "模仿学习"
  50. elif "contrastive" in summary_lower:
  51. method = "对比学习"
  52. elif "transformer" in summary_lower:
  53. method = "Transformer"
  54. elif "self-supervised" in summary_lower:
  55. method = "自监督学习"
  56. elif "representation" in summary_lower:
  57. method = "表征学习"
  58. else:
  59. method = "多种技术"
  60. # 判断应用领域
  61. if "robot" in summary_lower or "manipulation" in summary_lower:
  62. field = "机器人操作"
  63. elif "navigation" in summary_lower or "driving" in summary_lower:
  64. field = "导航控制"
  65. elif "translation" in summary_lower or "generation" in summary_lower:
  66. field = "生成任务"
  67. else:
  68. field = "相关任务"
  69. # 判断结果
  70. if "real-world" in summary_lower or "deployment" in summary_lower:
  71. result = "真实部署"
  72. elif "zero-shot" in summary_lower:
  73. result = "零样本泛化"
  74. elif "first" in paper.get("matched_innovation_terms", []) or "novel" in paper.get("matched_innovation_terms", []):
  75. result = "首次提出"
  76. elif "improve" in summary_lower or "better" in summary_lower:
  77. result = "性能提升"
  78. else:
  79. result = "性能优化"
  80. brief = f"{brief},采用{method}解决{field},实现{result}"
  81. return {
  82. "translated_abstract_zh": f"【LLM 暂不可用,先保留英文摘要要点】{truncate(summary, 220)}",
  83. "brief_explanation_zh": truncate(brief, 86),
  84. "tags": tags[:5],
  85. }
  86. def enrich_paper(paper: Dict[str, Any], model_names: List[str]) -> Dict[str, Any]:
  87. prompt = build_prompt(paper)
  88. result = None
  89. used_model = ""
  90. for model in model_names:
  91. model = normalize_space(model)
  92. if not model:
  93. continue
  94. log(f"Enriching {paper['arxiv_id']} with {model}")
  95. result = ollama_generate_json(prompt, model=model, timeout=150)
  96. if result:
  97. used_model = model
  98. break
  99. enriched = dict(paper)
  100. payload = result or fallback_enrichment(paper)
  101. tags = [normalize_space(tag).lstrip("#") for tag in payload.get("tags", []) if normalize_space(tag)]
  102. if not tags:
  103. tags = FALLBACK_TAGS.get(paper["domain"], [])[:5]
  104. enriched["translated_abstract_zh"] = normalize_space(payload.get("translated_abstract_zh", "")) or fallback_enrichment(paper)["translated_abstract_zh"]
  105. enriched["brief_explanation_zh"] = normalize_space(payload.get("brief_explanation_zh", "")) or fallback_enrichment(paper)["brief_explanation_zh"]
  106. enriched["tags"] = tags[:6]
  107. enriched["enrichment_model"] = used_model or "fallback"
  108. return enriched
  109. def enrich_selection(selection_payload: Dict[str, Any], model_names: List[str]) -> Dict[str, Any]:
  110. papers = selection_payload.get("papers", [])
  111. enriched_papers = [enrich_paper(paper, model_names=model_names) for paper in papers]
  112. by_domain: Dict[str, List[Dict[str, Any]]] = {domain: [] for domain in selection_payload.get("selected_by_domain", {})}
  113. for paper in enriched_papers:
  114. by_domain.setdefault(paper["domain"], []).append(paper)
  115. output = dict(selection_payload)
  116. output["papers"] = enriched_papers
  117. output["selected_by_domain"] = by_domain
  118. output["configured_models"] = model_names
  119. output["effective_models_used"] = list(
  120. dict.fromkeys(
  121. paper.get("enrichment_model", "")
  122. for paper in enriched_papers
  123. if paper.get("enrichment_model")
  124. )
  125. )
  126. return output
  127. def main() -> None:
  128. parser = argparse.ArgumentParser(description="Enrich RobotDaily papers with zh translation and tags")
  129. parser.add_argument("--input", required=True)
  130. parser.add_argument("--output", default="")
  131. parser.add_argument("--models", default="qwen3.5:27b")
  132. args = parser.parse_args()
  133. payload = read_json(args.input, default={}) or {}
  134. models = [item.strip() for item in args.models.split(",") if item.strip()]
  135. enriched = enrich_selection(payload, model_names=models)
  136. if args.output:
  137. write_json(args.output, enriched)
  138. else:
  139. print(json.dumps(enriched, ensure_ascii=False, indent=2))
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()