#!/usr/bin/env python3 """Select 2-3 promising papers per domain for RobotDaily.""" from __future__ import annotations import argparse import json from collections import defaultdict from typing import Any, Dict, List from fetch_arxiv import DOMAIN_CONFIGS from utils import log, now_local, read_json, write_json DOMAIN_ORDER = ["embodied", "representation", "reinforcement"] def paper_sort_key(paper: Dict[str, Any]) -> Any: return ( paper.get("score_total", 0.0), paper.get("score_applied", 0.0), paper.get("score_innovation", 0.0), paper.get("published", ""), ) def selection_reason(paper: Dict[str, Any]) -> str: reasons: List[str] = [] applied = paper.get("matched_applied_terms", [])[:3] innovation = paper.get("matched_innovation_terms", [])[:3] if applied: reasons.append("应用信号: " + ", ".join(applied)) if innovation: reasons.append("创新信号: " + ", ".join(innovation)) domain_matches = paper.get("domain_matches", {}).get(paper.get("domain", ""), [])[:3] if domain_matches: reasons.append("领域匹配: " + ", ".join(domain_matches)) if not reasons: reasons.append("综合得分较高,且发布时间较新") return ";".join(reasons) def choose_domain_papers(papers: List[Dict[str, Any]], min_per_domain: int = 2, max_per_domain: int = 3) -> List[Dict[str, Any]]: ranked = sorted(papers, key=paper_sort_key, reverse=True) if not ranked: return [] selected = ranked[:max_per_domain] if len(selected) >= 3: score_gap = selected[1].get("score_total", 0.0) - selected[2].get("score_total", 0.0) if score_gap > 1.2 or selected[2].get("score_total", 0.0) < 4.2: selected = selected[:2] if len(selected) < min_per_domain: selected = ranked[: min(min_per_domain, len(ranked))] output: List[Dict[str, Any]] = [] for index, paper in enumerate(selected, start=1): enriched = dict(paper) enriched["domain_rank"] = index enriched["selection_reason"] = selection_reason(paper) output.append(enriched) return output def select_papers(candidates: List[Dict[str, Any]], min_per_domain: int = 2, max_per_domain: int = 3) -> Dict[str, Any]: grouped: Dict[str, List[Dict[str, Any]]] = defaultdict(list) for paper in candidates: grouped[paper.get("domain", "representation")].append(paper) selected_by_domain: Dict[str, List[Dict[str, Any]]] = {} flat_selected: List[Dict[str, Any]] = [] for domain in DOMAIN_ORDER: picked = choose_domain_papers(grouped.get(domain, []), min_per_domain=min_per_domain, max_per_domain=max_per_domain) selected_by_domain[domain] = picked flat_selected.extend(picked) flat_selected.sort(key=lambda item: (DOMAIN_ORDER.index(item["domain"]), item.get("domain_rank", 0))) return { "generated_at": now_local().isoformat(), "counts": {domain: len(selected_by_domain[domain]) for domain in DOMAIN_ORDER}, "selected_by_domain": selected_by_domain, "papers": flat_selected, } def main() -> None: parser = argparse.ArgumentParser(description="Select daily papers for RobotDaily") parser.add_argument("--input", required=True) parser.add_argument("--output", default="") parser.add_argument("--min-per-domain", type=int, default=2) parser.add_argument("--max-per-domain", type=int, default=3) args = parser.parse_args() payload = read_json(args.input, default={}) or {} candidates = payload.get("papers", []) if isinstance(payload, dict) else [] selected = select_papers( candidates, min_per_domain=args.min_per_domain, max_per_domain=args.max_per_domain, ) log("Selected papers per domain: " + json.dumps(selected["counts"], ensure_ascii=False)) if args.output: write_json(args.output, selected) else: print(json.dumps(selected, ensure_ascii=False, indent=2)) if __name__ == "__main__": main()