day2_task.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. """
  2. Day 2 - K 近邻 练习
  3. 任务:实现 K 近邻 的核心算法
  4. """
  5. import numpy as np
  6. import matplotlib.pyplot as plt
  7. class KNearestNeighbor:
  8. """K 近邻 类"""
  9. def __init__(self, learning_rate: float = 0.01):
  10. self.learning_rate = learning_rate
  11. self.weights = None
  12. self.bias = None
  13. def forward(self, X: np.ndarray) -> np.ndarray:
  14. """前向传播
  15. Args:
  16. X: 输入数据,shape: [n_samples, n_features]
  17. Returns:
  18. 预测结果,shape: [n_samples]
  19. """
  20. # TODO: 实现 f(x) = sign(w · x + b)
  21. raise NotImplementedError
  22. def compute_loss(self, X: np.ndarray, y: np.ndarray) -> float:
  23. """计算损失"""
  24. # TODO: 实现损失函数
  25. raise NotImplementedError
  26. def update(self, X_i: np.ndarray, y_i: int):
  27. """更新参数
  28. Args:
  29. X_i: 单个样本,shape: [n_features]
  30. y_i: 标签,值域 (-1, 1)
  31. """
  32. # TODO: 实现梯度下降更新
  33. raise NotImplementedError
  34. def fit(self, X: np.ndarray, y: np.ndarray, max_iter: int = 100):
  35. """训练模型"""
  36. # TODO: 实现训练循环
  37. raise NotImplementedError
  38. def plot_concept():
  39. """可视化概念"""
  40. # 生成二维数据
  41. np.random.seed(42)
  42. X = np.random.randn(100, 2)
  43. y = np.sign(X[:, 0] + X[:, 1] - 0.5) * 1
  44. # 绘制散点图
  45. plt.figure(figsize=(8, 6))
  46. scatter = plt.scatter(X[:, 0], X[:, 1], c=y, cmap="bwr", s=100, edgecolors="black")
  47. plt.xlabel("x1")
  48. plt.ylabel("x2")
  49. plt.title("Day 2 - K 近邻 可视化")
  50. plt.colorbar(scatter)
  51. plt.grid(True, alpha=0.3)
  52. plt.savefig("./plots/day2_concept.png", dpi=150)
  53. print(f"✅ 可视化已保存:plots/day2_concept.png")
  54. if __name__ == "__main__":
  55. plot_concept()