梦里风林
  • Introduction
  • Android
    • activity
      • Activity四种启动模式
      • Intent Flag
      • 多task的应用
      • Task和回退栈
    • sqlite
      • 【源码】CursorWindow读DB
      • Sqlite在Android上的一个Bug
    • Chromium
    • ListView读取DB数据最佳实践
    • Android Project结构
    • 一个由Proguard与FastJson引起的血案
    • 琐碎的一些tips
  • Computer Vision
    • 特征提取
    • 三维视觉
    • 计算机视觉常用工具
    • 浅谈深度学习数据集设计
    • 随笔
  • Machine Learning
    • 技巧
      • FaceBook: 1 hour training ImageNet
      • L2 Norm与L2 normalize
    • 实践
      • Pytorch实验代码的亿些小细节
    • 工具
      • Tensorflow学习笔记
      • MXNet踩坑手记
      • PyTorch踩坑手记
      • PyTorch模型剪枝
      • Keras踩坑手记
      • mscnn
      • Matlab
        • Matlab Remote IPC自动化数据处理
    • Papers
      • Classification
      • Re-identification
        • CVPR2018:TFusion完全解读
        • ECCV2018:TAUDL
        • CVPR2018:Graph+reid
        • Person Re-identification
        • CVPR2016 Re-id
        • Camera topology and Person Re-id
        • Deep transfer learning Person Re-id
        • Evaluate
      • Object Detection
        • 读论文系列·干货满满的RCNN
        • 读论文系列·SPP-net
        • 读论文系列·Fast RCNN
        • 读论文系列·Faster RCNN
        • 读论文系列·YOLO
        • 读论文系列·SSD
        • 读论文系列·YOLOv2 & YOLOv3
        • 读论文系列·detection其他文章推荐
      • Depth
      • 3D vision
        • 数据集相关
        • 光流相关
      • Hashing
        • CVPR2018: SSAH
      • 大杂烩
        • CNCC2017 琐记
        • ECCV 2016 Hydra CCNN
        • CNCC2017深度学习与跨媒体智能
        • MLA2016笔记
    • 《机器学习》(周志华)读书笔记
      • 西瓜书概念整理
        • 绪论
        • 模型评估与选择
        • 线性模型
        • 决策树
        • 神经网络
        • 支持向量机
        • 贝叶斯分类器
        • 集成学习
        • 聚类
        • 降维与度量学习
        • 特征选择与稀疏学习
        • 计算学习理论
        • 半监督学习
        • 概率图模型
        • 规则学习
        • 强化学习
        • 附录
  • Java
    • java web
      • Servlet部署
      • 琐碎的tips
    • JNI
    • Note
    • Effective Java笔记
  • 后端开发
    • 架构设计
    • 数据库
    • java web
      • Servlet部署
      • 琐碎的tips
    • Spring boot
    • django
    • 分布式
  • Linux && Hardware
    • Ubuntu安装与初始配置
    • 树莓派相关
      • 树莓派3B+无线网卡监听模式
      • TP-LINK TL-WR703N v1.7 openwrt flashing
  • Python
    • django
    • 原生模块
    • 设计模式
    • 可视化
    • 常用库踩坑指南
  • web前端
    • header div固定,content div填充父容器
    • json接口资源
  • UI
  • kit
    • vim
    • git/github
      • 刷爆github小绿点
    • Markdown/gitbook
      • 琐碎知识点
      • gitbook添加disqus作为评论
      • 导出chrome书签为Markdown
      • Markdown here && 微信公众号
    • LaTex
      • LaTex琐记
    • 科学上网
    • 虚拟机
  • thinking-in-program
    • 怎样打日志
  • 我的收藏
  • 琐记
    • 论文心得
    • 深圳买房攻略
  • 赞赏支持
由 GitBook 提供支持
在本页
  • 原理
  • 样例

这有帮助吗?

  1. Machine Learning
  2. 工具

PyTorch模型剪枝

pytorch官方的剪枝工具分为非结构化和结构化剪枝两种,非结构化剪枝会随机地把一些权重参数变为0,结构化剪枝则将某个维度某些通道随机变成0,但这套工具不会真正输出剪枝后的模型,只是将模型变稀疏了,只有用某些特殊前向库,才能加快模型运行速度。最近找到一个库,能根据一定策略找到权重中作用较小的部分,用index表示,并保留对应的模型。

原理

首先构造输入,前向运行一次模型,得到模型对应的计算图

DG = tp.DependencyGraph()
DG.build_dependency(model, example_inputs=torch.randn(1,3,224,224))

对于计算图中的各种带权重的层,根据指定策略(目前支持ln,l1,l2等)比较权重中各个数值,找到第k小的数对应的index,记为将要删除的部分,

strategy = tp.strategy.L1Strategy() 
# 3. get a pruning plan from the dependency graph.
pruning_idxs = strategy(model.conv1.weight, amount=0.4) # or manually selected pruning_idxs=[2, 6, 9, ...]

得到index后,对依赖该层的其他层,递归使用对应的prune函数进行剪枝,得到每一层的剪枝计划。

pruning_plan = DG.get_pruning_plan( model.conv1, tp.prune_conv, idxs=pruning_idxs )

在剪枝计划执行时,根据index改变model中对应层的定义,使得model中的channel数变少。

# plune plane exec source code
def exec(self, dry_run=False):
    num_pruned = 0
    for dep, idxs in self._plans: # idxs were computed by specified strategy
        _, n = dep(idxs, dry_run=dry_run)
        num_pruned += n
    return num_pruned

以卷积层为例,执行剪枝计划时,根据strategy提供的idx,对weight和bias进行修剪:

class ConvPruning(BasePruningFunction):
    @staticmethod
    def prune_params(layer: nn.Module, idxs: Sequence[int]) -> nn.Module: 
        keep_idxs = list(set(range(layer.out_channels)) - set(idxs))
        layer.out_channels = layer.out_channels-len(idxs)
        if not layer.transposed:
            layer.weight = torch.nn.Parameter(layer.weight.data.clone()[keep_idxs])
        else:
            layer.weight = torch.nn.Parameter(layer.weight.data.clone()[:, keep_idxs])
        if layer.bias is not None:
            layer.bias = torch.nn.Parameter(layer.bias.data.clone()[keep_idxs])
        return layer
    
    @staticmethod
    def calc_nparams_to_prune(layer: nn.Module, idxs: Sequence[int]) -> int: 
        nparams_to_prune = len(idxs) * reduce(mul, layer.weight.shape[1:]) + (len(idxs) if layer.bias is not None else 0)
        return nparams_to_prune

如此,剪枝后model.forward时,运行的卷积层就是剪枝过的版本啦。

样例

  • 注意事项:

    • 剪枝是对遍历网络的每一个单独的层进行剪枝

    • 剪枝时虽然有一定的策略,但不能保证每个剪掉这些层之后损失就是最小的

    • 每层剪枝的比例可以不同,可以考虑人工将网络划分为几个部分,每个部分可以设置不同的剪枝比例,但具体应该设置多少,可以从剪枝后模型的执行结果进行评估,可以考虑写一个遍历算法,或者写个启发式搜索来找到最佳比例;

    • 剪枝后应当在新的模型上继续fine tune一定epoch,以得到最适合此网络结构的权重

上一页PyTorch踩坑手记下一页Keras踩坑手记

最后更新于3年前

这有帮助吗?

CIFAR10上ResNet18剪枝