enrich_papers.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340
  1. #!/usr/bin/env python3
  2. """Translate abstracts, generate tags, and produce short explanations."""
  3. from __future__ import annotations
  4. import argparse
  5. import json
  6. from typing import Any, Dict, List
  7. from fetch_arxiv import DOMAIN_CONFIGS
  8. from utils import log, normalize_space, ollama_generate_json, read_json, truncate, write_json
  9. FALLBACK_TAGS = {
  10. "embodied": ["具身智能", "机器人", "真实部署", "操控", "导航"],
  11. "representation": ["表征学习", "潜在空间", "世界模型", "预训练", "对象中心"],
  12. "reinforcement": ["强化学习", "策略优化", "奖励设计", "离线 RL", "模仿学习"],
  13. }
  14. def build_prompt(paper: Dict[str, Any]) -> str:
  15. domain_label = DOMAIN_CONFIGS[paper["domain"]]["label_zh"]
  16. return f"""
  17. 你是 RobotDaily 的论文晨报编辑。请根据给定的英文标题与英文摘要,输出严格 JSON。
  18. 只输出一个 JSON 对象,结构如下:
  19. {{
  20. "translated_abstract_zh": "...",
  21. "brief_explanation_zh": "...",
  22. "tags": ["标签 1", "标签 2", "标签 3", "标签 4", "标签 5"]
  23. }}
  24. 要求:
  25. 1. translated_abstract_zh:忠实翻译原摘要,不要增加原文没有的实验结果;控制在 180-400 个中文字符,必须完整覆盖原文摘要的所有要点。
  26. 2. brief_explanation_zh:40-90 个中文字符,说明为什么值得读,尽量偏应用价值和创新点。
  27. 3. tags:给 4-6 个适合直接贴在移动端卡片上的简短标签;尽量用中文,必要时保留通用英文术语,如 World Model、Offline RL。
  28. 4. 语气务实、技术导向,不要夸张。
  29. 5. 不要输出 Markdown,不要输出代码块。
  30. 领域:{domain_label}
  31. 标题:{paper['title']}
  32. 英文摘要:{paper['summary']}
  33. """.strip()
  34. def fallback_enrichment(paper: Dict[str, Any]) -> Dict[str, Any]:
  35. title = paper.get("title", "")
  36. summary = paper.get("summary", "")
  37. summary_lower = summary.lower()
  38. # 从标题提取核心方法名,优先使用冒号前的部分
  39. if ':' in title:
  40. core_method = title.split(':')[0].strip()
  41. elif '.' in title:
  42. core_method = title.split('.')[0].strip()
  43. else:
  44. core_method = title.strip()
  45. # 如果方法名太长(超过 20 字符),使用摘要中的关键词
  46. if len(core_method) > 20:
  47. # 从摘要中提取关键词
  48. if "diffusion" in summary_lower:
  49. core_method = "扩散模型框架"
  50. elif "reinforcement learning" in summary_lower:
  51. core_method = "强化学习框架"
  52. elif "imitation learning" in summary_lower:
  53. core_method = "模仿学习框架"
  54. elif "contrastive" in summary_lower:
  55. core_method = "对比学习框架"
  56. elif "transformer" in summary_lower:
  57. core_method = "Transformer 框架"
  58. elif "self-supervised" in summary_lower:
  59. core_method = "自监督学习框架"
  60. elif "representation" in summary_lower:
  61. core_method = "表征学习框架"
  62. elif "adaptation" in summary_lower or "adaptive" in summary_lower:
  63. core_method = "自适应框架"
  64. elif "multi-agent" in summary_lower or "marl" in summary_lower:
  65. core_method = "多智能体框架"
  66. elif "world model" in summary_lower:
  67. core_method = "世界模型框架"
  68. elif "residual policy" in summary_lower:
  69. core_method = "残差策略优化"
  70. elif "preference optimization" in summary_lower:
  71. core_method = "偏好优化"
  72. else:
  73. core_method = "创新框架"
  74. # 判断方法类型(优先级从高到低)
  75. method = "多种技术"
  76. if "residual policy" in summary_lower:
  77. method = "残差策略优化"
  78. elif "preference optimization" in summary_lower:
  79. method = "偏好优化"
  80. elif "diffusion" in summary_lower:
  81. method = "扩散模型"
  82. elif "reinforcement learning" in summary_lower or "rl" in summary_lower:
  83. method = "强化学习"
  84. elif "imitation learning" in summary_lower:
  85. method = "模仿学习"
  86. elif "contrastive" in summary_lower:
  87. method = "对比学习"
  88. elif "transformer" in summary_lower:
  89. method = "Transformer"
  90. elif "self-supervised" in summary_lower or "self supervised" in summary_lower:
  91. method = "自监督学习"
  92. elif "representation learning" in summary_lower:
  93. method = "表征学习"
  94. elif "adaptation" in summary_lower or "adaptive" in summary_lower:
  95. method = "自适应方法"
  96. elif "multi-agent" in summary_lower or "marl" in summary_lower:
  97. method = "多智能体强化学习"
  98. elif "world model" in summary_lower:
  99. method = "世界模型"
  100. # 判断应用领域(优先级从高到低)
  101. field = "相关任务"
  102. if "cloth" in summary_lower or "布料" in summary_lower:
  103. field = "布料操作"
  104. elif "piano" in summary_lower or "music" in summary_lower:
  105. field = "音乐演奏"
  106. elif "racing" in summary_lower or ("autonomous" in summary_lower and ("driving" in summary_lower or "racing" in summary_lower)):
  107. field = "自动驾驶"
  108. elif "medical" in summary_lower or "delivery" in summary_lower or "logistics" in summary_lower:
  109. field = "医疗物流"
  110. elif "motion" in summary_lower or "humanoid" in summary_lower:
  111. field = "人类动作生成"
  112. elif "navigation" in summary_lower and ("robot" in summary_lower or "policy" in summary_lower):
  113. field = "机器人导航"
  114. elif "navigation" in summary_lower:
  115. field = "导航控制"
  116. elif "traffic" in summary_lower or "scene understanding" in summary_lower:
  117. field = "交通场景理解"
  118. elif "map" in summary_lower or "localization" in summary_lower or "pose estimation" in summary_lower:
  119. field = "定位建图"
  120. elif "physical systems" in summary_lower or "emulator" in summary_lower:
  121. field = "物理系统模拟"
  122. elif "robot" in summary_lower or "manipulation" in summary_lower or "dexterous" in summary_lower:
  123. field = "机器人操作"
  124. else:
  125. # 从标题推断领域
  126. title_lower = title.lower()
  127. if "robot" in title_lower or "manipulation" in title_lower:
  128. field = "机器人操作"
  129. elif "navigation" in title_lower or "driving" in title_lower:
  130. field = "导航控制"
  131. elif "piano" in title_lower or "music" in title_lower:
  132. field = "音乐演奏"
  133. elif "cloth" in title_lower:
  134. field = "布料操作"
  135. elif "motion" in title_lower or "humanoid" in title_lower:
  136. field = "人类动作生成"
  137. elif "racing" in title_lower or "autonomous" in title_lower:
  138. field = "自动驾驶"
  139. # 判断结果/创新点
  140. if "real-world" in summary_lower or "deployment" in summary_lower:
  141. result = "真实部署"
  142. elif "zero-shot" in summary_lower:
  143. result = "零样本泛化"
  144. elif "first" in paper.get("matched_innovation_terms", []) or "novel" in paper.get("matched_innovation_terms", []):
  145. result = "首次提出"
  146. elif "improve" in summary_lower or "better" in summary_lower:
  147. result = "性能提升"
  148. elif "efficient" in summary_lower or "efficiently" in summary_lower:
  149. result = "高效"
  150. elif "robust" in summary_lower or "robustly" in summary_lower:
  151. result = "鲁棒性强"
  152. elif "generalize" in summary_lower or "generalization" in summary_lower:
  153. result = "泛化能力强"
  154. elif "few-shot" in summary_lower or "few shot" in summary_lower:
  155. result = "少样本学习"
  156. elif "sim-to-real" in summary_lower or "sim2real" in summary_lower:
  157. result = "仿真到现实迁移"
  158. else:
  159. result = "性能优化"
  160. # 格式:提出 XXX 框架,采用 XXX 技术,解决 XXX 问题,实现 XXX 效果
  161. brief = f"提出{core_method},采用{method}解决{field},实现{result}"
  162. # 从摘要提取具体标签(4-6 个),优先提取论文具体技术标签
  163. tags = []
  164. # 核心方法标签
  165. if "diffusion" in summary_lower:
  166. tags.append("扩散模型")
  167. if "reinforcement learning" in summary_lower:
  168. tags.append("强化学习")
  169. if "imitation learning" in summary_lower:
  170. tags.append("模仿学习")
  171. if "contrastive" in summary_lower:
  172. tags.append("对比学习")
  173. if "transformer" in summary_lower:
  174. tags.append("Transformer")
  175. if "self-supervised" in summary_lower:
  176. tags.append("自监督学习")
  177. if "multi-agent" in summary_lower or "marl" in summary_lower:
  178. tags.append("多智能体强化学习")
  179. if "world model" in summary_lower:
  180. tags.append("世界模型")
  181. if "residual policy" in summary_lower:
  182. tags.append("残差策略优化")
  183. if "preference optimization" in summary_lower:
  184. tags.append("偏好优化")
  185. if "representation learning" in summary_lower:
  186. tags.append("表征学习")
  187. if "adaptation" in summary_lower or "adaptive" in summary_lower:
  188. tags.append("自适应")
  189. # 具体任务标签
  190. if "robot" in summary_lower and "manipulation" in summary_lower:
  191. tags.append("机器人操作")
  192. if "dexterous" in summary_lower:
  193. tags.append("灵巧操作")
  194. if "navigation" in summary_lower:
  195. tags.append("导航")
  196. if "driving" in summary_lower or "racing" in summary_lower:
  197. tags.append("自动驾驶")
  198. if "cloth" in summary_lower:
  199. tags.append("布料操作")
  200. if "piano" in summary_lower:
  201. tags.append("音乐演奏")
  202. if "humanoid" in summary_lower or "motion" in summary_lower:
  203. tags.append("动作生成")
  204. if "localization" in summary_lower or "pose estimation" in summary_lower:
  205. tags.append("定位")
  206. if "traffic" in summary_lower:
  207. tags.append("交通场景")
  208. if "map" in summary_lower:
  209. tags.append("建图")
  210. # 结果标签
  211. if "zero-shot" in summary_lower:
  212. tags.append("零样本")
  213. if "real-world" in summary_lower:
  214. tags.append("真实部署")
  215. if "deployment" in summary_lower:
  216. tags.append("部署")
  217. if "sim-to-real" in summary_lower or "sim2real" in summary_lower:
  218. tags.append("仿真到现实")
  219. if "generalization" in summary_lower:
  220. tags.append("泛化能力")
  221. if "few-shot" in summary_lower:
  222. tags.append("少样本")
  223. if "efficient" in summary_lower:
  224. tags.append("高效")
  225. if "robust" in summary_lower:
  226. tags.append("鲁棒性")
  227. # 如果标签数量不足 4 个,添加领域特定标签
  228. if len(tags) < 4:
  229. domain_tags = {
  230. "embodied": ["具身智能", "机器人", "真实部署", "操控", "灵巧操作"],
  231. "representation": ["表征学习", "潜在空间", "世界模型", "预训练", "自监督"],
  232. "reinforcement": ["强化学习", "策略优化", "奖励设计", "离线 RL", "模仿学习"],
  233. }
  234. fallback = domain_tags.get(paper["domain"], ["AI 论文", "机器学习", "应用研究", "深度学习"])
  235. for tag in fallback:
  236. if tag not in tags:
  237. tags.append(tag)
  238. if len(tags) >= 6:
  239. break
  240. # 去重并限制数量
  241. tags = list(dict.fromkeys(tags))[:6]
  242. return {
  243. "translated_abstract_zh": f"【LLM 暂不可用,先保留英文摘要要点】{truncate(summary, 220)}",
  244. "brief_explanation_zh": truncate(brief, 86),
  245. "tags": tags,
  246. }
  247. def enrich_paper(paper: Dict[str, Any], model_names: List[str]) -> Dict[str, Any]:
  248. prompt = build_prompt(paper)
  249. result = None
  250. used_model = ""
  251. for model in model_names:
  252. model = normalize_space(model)
  253. if not model:
  254. continue
  255. log(f"Enriching {paper['arxiv_id']} with {model}")
  256. result = ollama_generate_json(prompt, model=model, timeout=150)
  257. if result:
  258. used_model = model
  259. break
  260. enriched = dict(paper)
  261. payload = result or fallback_enrichment(paper)
  262. tags = [normalize_space(tag).lstrip("#") for tag in payload.get("tags", []) if normalize_space(tag)]
  263. if not tags:
  264. tags = FALLBACK_TAGS.get(paper["domain"], [])[:5]
  265. enriched["translated_abstract_zh"] = normalize_space(payload.get("translated_abstract_zh", "")) or fallback_enrichment(paper)["translated_abstract_zh"]
  266. enriched["brief_explanation_zh"] = normalize_space(payload.get("brief_explanation_zh", "")) or fallback_enrichment(paper)["brief_explanation_zh"]
  267. enriched["tags"] = tags[:6]
  268. enriched["enrichment_model"] = used_model or "fallback"
  269. return enriched
  270. def enrich_selection(selection_payload: Dict[str, Any], model_names: List[str]) -> Dict[str, Any]:
  271. papers = selection_payload.get("papers", [])
  272. enriched_papers = [enrich_paper(paper, model_names=model_names) for paper in papers]
  273. by_domain: Dict[str, List[Dict[str, Any]]] = {domain: [] for domain in selection_payload.get("selected_by_domain", {})}
  274. for paper in enriched_papers:
  275. by_domain.setdefault(paper["domain"], []).append(paper)
  276. output = dict(selection_payload)
  277. output["papers"] = enriched_papers
  278. output["selected_by_domain"] = by_domain
  279. output["configured_models"] = model_names
  280. output["effective_models_used"] = list(
  281. dict.fromkeys(
  282. paper.get("enrichment_model", "")
  283. for paper in enriched_papers
  284. if paper.get("enrichment_model")
  285. )
  286. )
  287. return output
  288. def main() -> None:
  289. parser = argparse.ArgumentParser(description="Enrich RobotDaily papers with zh translation and tags")
  290. parser.add_argument("--input", required=True)
  291. parser.add_argument("--output", default="")
  292. parser.add_argument("--models", default="qwen3.5:27b")
  293. args = parser.parse_args()
  294. payload = read_json(args.input, default={}) or {}
  295. models = [item.strip() for item in args.models.split(",") if item.strip()]
  296. enriched = enrich_selection(payload, model_names=models)
  297. if args.output:
  298. write_json(args.output, enriched)
  299. else:
  300. print(json.dumps(enriched, ensure_ascii=False, indent=2))
  301. if __name__ == "__main__":
  302. main()