# bert-text-classification **Repository Path**: zhihonglin/bert-text-classification ## Basic Information - **Project Name**: bert-text-classification - **Description**: No description available - **Primary Language**: Unknown - **License**: Not specified - **Default Branch**: main - **Homepage**: None - **GVP Project**: No ## Statistics - **Stars**: 0 - **Forks**: 0 - **Created**: 2025-06-16 - **Last Updated**: 2026-03-12 ## Categories & Tags **Categories**: Uncategorized **Tags**: None ## README # 🤖 BERT 文本分类项目 [![Python Version](https://img.shields.io/badge/Python-3.10%2B-blue)](https://www.python.org/) [![Framework](https://img.shields.io/badge/Framework-PyTorch-orange)](https://pytorch.org/) [![License](https://img.shields.io/badge/License-MIT-green)](LICENSE) 基于 BERT 预训练模型的中文文本十分类项目,适合 NLP 初学者学习 BERT 微调的完整流程。 ## 📖 项目简介 本项目使用 `bert-base-chinese` 预训练模型,在其基础上添加一个全连接层进行文本分类微调。 - **任务类型**:十分类(新闻标题分类) - **模型结构**:BERT + 全连接层 - **模型大小**:约 400MB - **训练策略**:冻结 BERT 参数,只训练全连接层 ### 🏷️ 分类类别 | 标签ID | 类别 | | ---- | ------------------ | | 0 | 财经 (finance) | | 1 | 房产 (realty) | | 2 | 股票 (stocks) | | 3 | 教育 (education) | | 4 | 科技 (science) | | 5 | 社会 (society) | | 6 | 时政 (politics) | | 7 | 体育 (sports) | | 8 | 游戏 (game) | | 9 | 娱乐 (entertainment) | ## 📂 项目结构 ``` bert-text-classification/ ├── data/ # 数据目录 │ ├── train.txt # 训练集(原始文本) │ ├── dev.txt # 验证集 │ ├── test.txt # 测试集 │ ├── class.txt # 类别列表 │ └── 预处理完的数据格式的demo示例.txt ├── src/ # 源代码目录 │ ├── bt_config.py # 配置文件 │ ├── train.py # 训练主程序 │ ├── model/ │ │ └── bt_model.py # 模型定义 │ ├── data_handle/ │ │ ├── preprocess.py # 数据预处理 │ │ ├── dataset.py # Dataset类 │ │ └── data_loader.py # DataLoader创建 │ └── utils/ │ ├── checkpoint.py # 模型保存/加载 │ └── logger.py # 日志工具 ├── test/ # 测试脚本(学习用) │ ├── test1_path.py # 路径处理测试 │ ├── test2_看看bert的分词器怎么使用.py │ ├── test3_批量化处理数据.py │ ├── test4_测试dataloader.py │ ├── test5_看看bert模型.py │ └── test6_测试bert模型前向计算.py ├── doc/ # 文档目录 │ ├── tip.md # 开发技巧 │ └── 模型结构图.md # BERT结构说明 ├── bert_pretrain/ # 预训练模型目录(需自行下载) ├── saved_model/ # 保存的模型 └── readme.md ``` ## 🚀 快速开始 ### 1. 📥 克隆项目 ```bash git clone https://gitee.com/zhihonglin/bert-text-classification.git cd bert-text-classification # 进入项目根目录 ``` ### 2. 🐍 创建环境 ```bash conda create -n bert_cls python=3.10 conda activate bert_cls ``` ### 3. 📦 安装依赖 ```bash pip install torch --index-url https://download.pytorch.org/whl/cu118 pip install transformers pandas tqdm # 或者:pip install -r requirements.txt ``` ### 4. 🧠 下载预训练模型 从魔搭下载 `bert-base-chinese` 模型,放入 `bert_pretrain/` 目录。推荐使用命令行方式: ```bash pip install modelscope modelscope download --model google-bert/bert-base-chinese --local_dir ./bert_pretrain ``` > **其他下载方式**: > > - **SDK 下载**: > ```python > from modelscope import snapshot_download > model_dir = snapshot_download('google-bert/bert-base-chinese', cache_dir='./bert_pretrain') > ``` > - **Git 下载**: > ```bash > git lfs install > git clone https://www.modelscope.cn/google-bert/bert-base-chinese.git bert_pretrain > ``` 下载完成后,`bert_pretrain/` 目录应包含:`config.json`、`pytorch_model.bin`、`vocab.txt` ### 5. 🧹 数据预处理 将原始文本转换为模型可用的 JSONL 格式: ```bash python src/data_handle/preprocess.py ``` > **数据格式说明**: > > - 输入(train.txt):`文本内容\t标签ID` > - 输出(train.jsonl):`{"input_ids": [...], "attention_mask": [...], "label": 3}` ### 6. 🏋️ 开始训练 ```bash python src/train.py ``` > **提示**:程序支持断点续训,再次运行 `python src/train.py` 将自动加载最新的 checkpoint。 ## 💻 核心代码说明 ### 🧠 模型结构 ```python class BT_model(nn.Module): def __init__(self, config: BT_Config): super(BT_model, self).__init__() self.bert = BertModel.from_pretrained(config.bert_path) # 冻结 BERT 参数,只训练分类头 for param in self.bert.parameters(): param.requires_grad = False self.fc = nn.Linear(config.hidden_size, 10) # 10分类 def forward(self, input_ids, attention_mask): with torch.no_grad(): # 不计算 BERT 的梯度 outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask) x1 = outputs.pooler_output # [CLS] 位置的向量 logits = self.fc(x1) return logits ``` ### ⚙️ 配置参数 | 参数 | 默认值 | 说明 | | ------------ | ---- | ---------- | | epochs | 4 | 训练轮次 | | batch_size | 4 | 批次大小 | | pad_size | 35 | 句子最大长度 | | lr | 5e-5 | 学习率 | | hidden_size | 768 | BERT 隐藏层维度 | ## 📚 学习路径建议 1. **test1_path.py** - 了解 Python 路径处理 2. **test2_看看bert的分词器怎么使用.py** - 学习 BERT Tokenizer 3. **test3_批量化处理数据.py** - 理解批量数据处理 4. **test4_测试dataloader.py** - 掌握 PyTorch DataLoader 5. **test5_看看bert模型.py** - 查看 BERT 模型结构 6. **test6_测试bert模型前向计算.py** - 理解模型前向传播 7. 阅读 **train.py** - 完整训练流程 8. 阅读 **doc/tip.md** - 开发技巧和注意事项 ## 🔗 参考资料 - [BERT 论文](https://arxiv.org/abs/1810.04805) - [Transformers 文档](https://huggingface.co/docs/transformers/) - [PyTorch 官方教程](https://pytorch.org/tutorials/) ## 💡 常见问题 **Q: CUDA out of memory?** A: 减小 batch_size 或 pad_size **Q: 找不到 bert_pretrain?** A: 确保已下载预训练模型到正确目录 **Q: import 报错?** A: 检查 PYTHONPATH 是否正确设置