select_papers.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. #!/usr/bin/env python3
  2. """Select 2-3 promising papers per domain for RobotDaily."""
  3. from __future__ import annotations
  4. import argparse
  5. import json
  6. from collections import defaultdict
  7. from typing import Any, Dict, List
  8. from fetch_arxiv import DOMAIN_CONFIGS
  9. from utils import log, now_local, read_json, write_json
# Domains are processed, counted and emitted in this fixed order by
# select_papers(); candidates whose "domain" is not listed here are never
# selected.
DOMAIN_ORDER = ["embodied", "representation", "reinforcement"]
  11. def paper_sort_key(paper: Dict[str, Any]) -> Any:
  12. return (
  13. paper.get("score_total", 0.0),
  14. paper.get("score_applied", 0.0),
  15. paper.get("score_innovation", 0.0),
  16. paper.get("published", ""),
  17. )
  18. def selection_reason(paper: Dict[str, Any]) -> str:
  19. reasons: List[str] = []
  20. applied = paper.get("matched_applied_terms", [])[:3]
  21. innovation = paper.get("matched_innovation_terms", [])[:3]
  22. if applied:
  23. reasons.append("应用信号: " + ", ".join(applied))
  24. if innovation:
  25. reasons.append("创新信号: " + ", ".join(innovation))
  26. domain_matches = paper.get("domain_matches", {}).get(paper.get("domain", ""), [])[:3]
  27. if domain_matches:
  28. reasons.append("领域匹配: " + ", ".join(domain_matches))
  29. if not reasons:
  30. reasons.append("综合得分较高,且发布时间较新")
  31. return ";".join(reasons)
  32. def choose_domain_papers(papers: List[Dict[str, Any]], min_per_domain: int = 2, max_per_domain: int = 3) -> List[Dict[str, Any]]:
  33. ranked = sorted(papers, key=paper_sort_key, reverse=True)
  34. if not ranked:
  35. return []
  36. selected = ranked[:max_per_domain]
  37. if len(selected) >= 3:
  38. score_gap = selected[1].get("score_total", 0.0) - selected[2].get("score_total", 0.0)
  39. if score_gap > 1.2 or selected[2].get("score_total", 0.0) < 4.2:
  40. selected = selected[:2]
  41. if len(selected) < min_per_domain:
  42. selected = ranked[: min(min_per_domain, len(ranked))]
  43. output: List[Dict[str, Any]] = []
  44. for index, paper in enumerate(selected, start=1):
  45. enriched = dict(paper)
  46. enriched["domain_rank"] = index
  47. enriched["selection_reason"] = selection_reason(paper)
  48. output.append(enriched)
  49. return output
  50. def select_papers(candidates: List[Dict[str, Any]], min_per_domain: int = 2, max_per_domain: int = 3) -> Dict[str, Any]:
  51. grouped: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
  52. for paper in candidates:
  53. grouped[paper.get("domain", "representation")].append(paper)
  54. selected_by_domain: Dict[str, List[Dict[str, Any]]] = {}
  55. flat_selected: List[Dict[str, Any]] = []
  56. for domain in DOMAIN_ORDER:
  57. picked = choose_domain_papers(grouped.get(domain, []), min_per_domain=min_per_domain, max_per_domain=max_per_domain)
  58. selected_by_domain[domain] = picked
  59. flat_selected.extend(picked)
  60. flat_selected.sort(key=lambda item: (DOMAIN_ORDER.index(item["domain"]), item.get("domain_rank", 0)))
  61. return {
  62. "generated_at": now_local().isoformat(),
  63. "counts": {domain: len(selected_by_domain[domain]) for domain in DOMAIN_ORDER},
  64. "selected_by_domain": selected_by_domain,
  65. "papers": flat_selected,
  66. }
  67. def main() -> None:
  68. parser = argparse.ArgumentParser(description="Select daily papers for RobotDaily")
  69. parser.add_argument("--input", required=True)
  70. parser.add_argument("--output", default="")
  71. parser.add_argument("--min-per-domain", type=int, default=2)
  72. parser.add_argument("--max-per-domain", type=int, default=3)
  73. args = parser.parse_args()
  74. payload = read_json(args.input, default={}) or {}
  75. candidates = payload.get("papers", []) if isinstance(payload, dict) else []
  76. selected = select_papers(
  77. candidates,
  78. min_per_domain=args.min_per_domain,
  79. max_per_domain=args.max_per_domain,
  80. )
  81. log("Selected papers per domain: " + json.dumps(selected["counts"], ensure_ascii=False))
  82. if args.output:
  83. write_json(args.output, selected)
  84. else:
  85. print(json.dumps(selected, ensure_ascii=False, indent=2))
  86. if __name__ == "__main__":
  87. main()