# GeoPool-Net **Repository Path**: GFancy/geo-pool-net ## Basic Information - **Project Name**: GeoPool-Net - **Description**: 用池化操作实现特征提取 - **Primary Language**: Unknown - **License**: Apache-2.0 - **Default Branch**: master - **Homepage**: None - **GVP Project**: No ## Statistics - **Stars**: 0 - **Forks**: 0 - **Created**: 2026-03-08 - **Last Updated**: 2026-03-09 ## Categories & Tags **Categories**: Uncategorized **Tags**: None ## README # SemanticGuidedNet + TextROINet **语义引导局部采样网络** — "深层找位置,浅层取细节" > 核心思想:用深层特征预测 ROI 坐标,在像素空间裁剪局部区域,与全局路径融合,通过共享骨干权重提升对小目标的感知能力。 > 同一套 ROI 机制分别迁移到**目标检测**(SemanticGuidedNet v3)和**场景文字识别**(TextROINet)两个任务。 --- ## 项目结构 ``` ImageTrain/ ├── models/ │ ├── dual_branch_net.py # SemanticGuidedNet v3(目标检测,12.6M) │ ├── loss.py # CIoU + FocalLoss + YOLOLoss │ ├── text_roi_net.py # TextROINet(场景文字识别,4.32M) │ └── str_loss.py # CTCLoss + ROIRegularizer + CenterDiversityLoss ├── utils/ │ ├── dataset.py # YOLO格式目标检测数据集 │ └── str_dataset.py # STR数据集(folder/lmdb格式) ├── scripts/ │ ├── download_str_data.py # 合成STR数据生成(~1000词表,v2) │ └── prepare_coco128.py # COCO128数据适配脚本 ├── train.py # 目标检测训练入口 ├── train_str.py # STR训练入口(支持课程学习、CSV日志) ├── verify.py # 快速验证脚本 ├── analyze.py # 训练结果分析脚本 └── RESEARCH.md # 完整研究总结(含架构图、实验分析) ``` --- ## 快速开始 ### 环境要求 ```bash pip install torch torchvision Pillow numpy ``` ### 任务一:场景文字识别(STR) #### Step 1 — 生成合成训练数据 ```bash python scripts/download_str_data.py --synthetic 10000 --out data/str_synth_v2 ``` 生成 10000 张合成英文单词图像(128×32 px): - ~1000 词表:短词(2-4字) 30% / 中词(5-7字) 50% / 长词(8-12字) 20% - 6 种配色方案(浅底深字 / 深底浅字) - 动态字体使文字填满 85-98% 图像宽度 - 40% 轻微旋转 ±5°,20% 仿射剪切 ±0.05 #### Step 2 — 训练 ROI 模型(课程学习,推荐) ```bash python train_str.py --data data/str_synth_v2 --data-format folder \ --epochs 50 --batch 64 --workers 4 \ --curriculum --curriculum-warmup 5 \ --lambda-center 0.05 --log-csv \ --save-dir runs/str_roi_v2 ``` 课程学习策略: - epoch 1-5:冻结 ROI 预测器,让 CTC 主路先收敛 - epoch 6+:解冻 ROI 预测器,学习率降至 3e-4 #### Step 3 — 训练 Baseline 对比模型 ```bash python train_str.py --data data/str_synth_v2 --data-format folder \ --epochs 50 --batch 64 --workers 4 \ --baseline --log-csv \ --save-dir runs/str_baseline_v2 ``` #### Smoke Test(快速验证,8 epoch) ```bash python train_str.py --data data/str_synth_v2 --data-format folder \ --epochs 8 --batch 16 --workers 0 \ --curriculum --curriculum-warmup 3 \ --lambda-center 0.05 --log-csv \ --save-dir runs/str_smoke_v2 ``` --- ### 任务二:目标检测(SemanticGuidedNet) #### 准备 COCO128 数据 ```bash python scripts/prepare_coco128.py ``` #### 训练 v3 ROI 模型 ```bash python train.py --epochs 100 --batch 4 --save-dir runs/v3 ``` #### 训练 Baseline 对比 ```bash python train.py --baseline --epochs 100 --batch 4 --save-dir runs/baseline ``` --- ## 模型架构 ### TextROINet(STR) ``` 输入 [B, 3, 32, 128] │ ├─── ThinCNN.forward_deep ──► deep_feat [B,256,4,64] │ └► seq_global [B,256,1,64] │ │ deep_feat → ROIPredictor → center[B,2], scale[B,1] │ ├─── pixel_roi_crop(x, center, scale, out=(32,64)) │ └► roi_patch [B,3,32,64] │ └─ ThinCNN.forward_deep ─► seq_local [B,256,1,32] │ └─ interpolate ─► [B,256,1,64] │ cat([seq_global, seq_local_up]) [B,512,1,64] └─ C2fSimple(512→256) ─► BiLSTM(256→512) ─► Linear(512→37) ─► [T=64, B, 37] ``` **Baseline**:用 `cat(seq_global, seq_global)` 替代局部路径,参数量对齐。 ### SemanticGuidedNet v3(目标检测) ``` 输入图像 └─ 共享 Backbone (stem+stage1) ─► 全局特征 s1_global └─ ROIPredictor (深层 p5) ──────► center, scale └─ pixel_roi_crop ──────────────► roi_patch └─ 共享 stem+stage1 ────────► ls1 → interpolate → ls1_up cat(s1_global, ls1_up) → C2f → FPN p3 检测头 ``` --- ## 损失函数 ### STR 损失 | 组件 | 公式 | 作用 | |------|------|------| | CTC Loss | `-log P(y|x)` | 主任务 | | ROI Regularizer | `relu(scale - 0.6).mean()` | 防止 scale 退化为全图 | | CenterDiversityLoss | `-std(center, dim=0).mean()` | 鼓励批内 ROI 中心多样性,防崩溃 | 总损失:`total = ctc + λ_roi × roi_reg + λ_center × center_div` --- ## 关键超参数 | 参数 | 默认值 | 说明 | |------|--------|------| | `--lambda-roi` | 0.1 | ROI 正则化权重 | | `--lambda-center` | 0.05 | 中心多样性损失权重 | | `--curriculum-warmup` | 5 | 冻结 ROI 预测器的 epoch 数 | | `roi_min_scale` | 0.4 | ROI 最小缩放比例 | | `roi_max_scale` | 0.7 | ROI 最大缩放比例 | --- ## 已知问题 & TODO ### ROI 退化问题 - 现象:训练后期 center std ≈ 0,所有样本 ROI 中心固定到 [0.5, 0.5] - 当前缓解:`CenterDiversityLoss`(对抗 center 退化)+ 课程学习(让 CTC 先收敛) - TODO:观察 50 epoch 后 center std 是否显著改善 ### 后续实验 1. **P1(必做)** 加入独立验证集重跑 50 epoch,用 word_acc 做最终对比 2. **P1** ROI 可视化:在原图标注裁剪框,验证 center 是否对准关键文字区域 3. **P2** 尝试更强的 CenterDiversityLoss 权重(λ_center=0.1~0.2),加速 ROI 探索 4. **中期** 真实 STR 数据集(MJSynth / IIIT5K)验证泛化 5. **中期** 多 ROI 扩展(N 个并行局部截取,投票融合) 6. **长期** 递归 ROI(ROI 内再预测 ROI) --- ## 实验结果记录 | 实验 | 任务 | Epoch | 最终 train_ctc | 备注 | |------|------|-------|----------------|------| | smoke_v1 | 检测 | 3 | loss 15.9 | v1 概念验证 | | smoke_pixel | 检测 | 3 | loss 3.02 | v3 像素空间 ROI | | v3 | 检测 | 100 | val_loss ~2.09 | batch=4,ROI明显优于baseline | | baseline(检测) | 检测 | 100 | val_loss ~14.7 | v3 优势明显 | | str_smoke_v2 | STR | 8 | 0.007 | v2改进+课程学习 | | **str_roi_v2** | **STR** | **50** | **0.000174** | **课程学习,ROI center std 最终 ≈ 0.27** | | **str_baseline_v2** | **STR** | **50** | **0.000128** | **纯CTC,无ROI** | ### STR 50-epoch 实验分析 **ROI 模型(str_roi_v2)关键观测:** | 阶段 | Epoch | train_ctc | center_std (y) | scale | 现象 | |------|-------|-----------|----------------|-------|------| | CTC 预热(冻结ROI) | 1→5 | 4.49→0.065 | ~0 | 0.550 | CTC 快速收敛,ROI 静止 | | ROI 解冻初期 | 6→10 | 0.020→0.002 | 0→0.004 | 0.550→0.550 | center 开始微小扰动 | | ROI 激活 | 11→20 | 0.0015→0.00084 | 0.020→0.173 | 0.551→0.552 | center_std 指数增长,ROI 开始搜索 | | ROI 稳定 | 21→35 | ~0.0007 | 0.18→0.27 | 0.543→0.406 | scale 下降(ROI 缩小聚焦),std 继续增大 | | 收敛期 | 36→50 | ~0.00013 | ~0.28 | ~0.403 | ROI 参数稳定,scale≈0.40(原始0.55→缩小22%) | **结论:** - ✅ `CenterDiversityLoss` 成功对抗退化:center std 从 0 增长到 ~0.27,ROI 不再固定到 [0.5, 0.5] - ✅ scale 从初始 0.55 下降到 0.40,说明 ROI 学会了聚焦更小的文字区域 - ✅ 课程学习有效:CTC 在 5 epoch 内从 4.49 降至 0.065,再开放 ROI 学习局部细节 - ⚠️ train_ctc 收敛值 ROI(0.000174) > baseline(0.000128):ROI 引入了额外优化难度 - ⚠️ 本实验无验证集(val=nan):需要引入独立测试集才能判断泛化能力 **下一步:** 用独立验证集(`--val`)重跑,观察 ROI 是否带来 word_acc 提升。 --- ## 代码验证 ```bash # 验证模型形状 python -c " from models.text_roi_net import build_text_model import torch m = build_text_model() lp, c, s = m(torch.randn(4,3,32,128), return_roi=True) print(lp.shape, c[0], s[0].item()) # 期望: [64,4,37] [0.5,0.5] 0.55 " # 验证损失函数 python -c " import torch from models.str_loss import STRLoss crit = STRLoss(blank=36, lambda_center=0.05) lp = torch.randn(64,4,37).log_softmax(-1) tgt = torch.tensor([0,1,2,3,0,1,4,5,0,1,0,1], dtype=torch.long) loss, d = crit(lp, tgt, torch.full((4,),64,dtype=torch.long), torch.tensor([3,2,3,4],dtype=torch.long), scale=torch.rand(4,1)*0.3+0.4, center=torch.rand(4,2)*0.3+0.35) print(d) # 期望: center_div 为负值 " # 验证检测模型 python -c " from models.dual_branch_net import build_model import torch m = build_model(baseline=False) out = m(torch.randn(2,3,640,640)) print([o.shape for o in out]) # 期望: 三个检测头输出 " ``` --- ## 参考文献 - YOLOv8 架构:Ultralytics YOLO - CTC Loss:Graves et al., 2006 - Spatial Transformer Networks:Jaderberg et al., 2015 - Deep Text Recognition Benchmark:Baek et al., 2019