# PyTorch模型剪枝

pytorch官方的剪枝工具分为非结构化和结构化剪枝两种，非结构化剪枝会随机地把一些权重参数变为0，结构化剪枝则将某个维度某些通道随机变成0，但这套工具不会真正输出剪枝后的模型，只是将模型变稀疏了，只有用某些特殊前向库，才能加快模型运行速度。最近找到一个库，能根据一定策略找到权重中作用较小的部分，用index表示，并保留对应的模型。

## 原理

首先构造输入，前向运行一次模型，得到模型对应的计算图

```python
DG = tp.DependencyGraph()
DG.build_dependency(model, example_inputs=torch.randn(1,3,224,224))
```

对于计算图中的各种带权重的层，根据指定策略（目前支持ln,l1,l2等）比较权重中各个数值，找到第k小的数对应的index，记为将要删除的部分，

```python
strategy = tp.strategy.L1Strategy() 
# 3. get a pruning plan from the dependency graph.
pruning_idxs = strategy(model.conv1.weight, amount=0.4) # or manually selected pruning_idxs=[2, 6, 9, ...]
```

得到index后，对依赖该层的其他层，递归使用对应的prune函数进行剪枝，得到每一层的剪枝计划。

```python
pruning_plan = DG.get_pruning_plan( model.conv1, tp.prune_conv, idxs=pruning_idxs )
```

在剪枝计划执行时，根据index改变model中对应层的定义，使得model中的channel数变少。

```python
# plune plane exec source code
def exec(self, dry_run=False):
    num_pruned = 0
    for dep, idxs in self._plans: # idxs were computed by specified strategy
        _, n = dep(idxs, dry_run=dry_run)
        num_pruned += n
    return num_pruned
```

以卷积层为例，执行剪枝计划时，根据strategy提供的idx，对weight和bias进行修剪：

```python
class ConvPruning(BasePruningFunction):
    @staticmethod
    def prune_params(layer: nn.Module, idxs: Sequence[int]) -> nn.Module: 
        keep_idxs = list(set(range(layer.out_channels)) - set(idxs))
        layer.out_channels = layer.out_channels-len(idxs)
        if not layer.transposed:
            layer.weight = torch.nn.Parameter(layer.weight.data.clone()[keep_idxs])
        else:
            layer.weight = torch.nn.Parameter(layer.weight.data.clone()[:, keep_idxs])
        if layer.bias is not None:
            layer.bias = torch.nn.Parameter(layer.bias.data.clone()[keep_idxs])
        return layer
    
    @staticmethod
    def calc_nparams_to_prune(layer: nn.Module, idxs: Sequence[int]) -> int: 
        nparams_to_prune = len(idxs) * reduce(mul, layer.weight.shape[1:]) + (len(idxs) if layer.bias is not None else 0)
        return nparams_to_prune
```

如此，剪枝后model.forward时，运行的卷积层就是剪枝过的版本啦。

## 样例

* [CIFAR10上ResNet18剪枝](https://github.com/VainF/Torch-Pruning/blob/master/examples/prune_resnet18_cifar10.py)
* 注意事项：
  * 剪枝是对遍历网络的每一个单独的层进行剪枝
  * 剪枝时虽然有一定的策略，但不能保证每个剪掉这些层之后损失就是最小的
  * 每层剪枝的比例可以不同，可以考虑人工将网络划分为几个部分，每个部分可以设置不同的剪枝比例，但具体应该设置多少，可以从剪枝后模型的执行结果进行评估，可以考虑写一个遍历算法，或者写个启发式搜索来找到最佳比例；
  * 剪枝后应当在新的模型上继续fine tune一定epoch，以得到最适合此网络结构的权重


---

# Agent Instructions: Querying This Documentation

If you need additional information that is not directly available in this page, you can query the documentation dynamically by asking a question.

Perform an HTTP GET request on the current page URL with the `ask` query parameter:

```
GET https://blog.cweihang.io/ml/kit/torch_prune.md?ask=<question>
```

The question should be specific, self-contained, and written in natural language.
The response will contain a direct answer to the question and relevant excerpts and sources from the documentation.

Use this mechanism when the answer is not explicitly present in the current page, you need clarification or additional context, or you want to retrieve related documentation sections.
