PyTorch模型剪枝
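原理

先用 `tp.DependencyGraph` 追踪模型，建立各层之间的依赖关系；再由剪枝策略（如 `L1Strategy`）选出 `conv1` 中要剪除的索引；然后从依赖图得到剪枝计划（pruning plan）。该计划由若干 (依赖, 索引) 步骤组成，涵盖受 `conv1` 剪枝影响的相关层。最后调用计划的 `exec()`，依次执行每一步剪枝。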
# NOTE(review): this snippet assumes `import torch`, `import torch_pruning as tp`,
# and a `model` that has a `conv1` layer are defined earlier -- they are not shown.

# 1. Build the layer dependency graph by tracing the model with a dummy input
#    (a tensor of shape (1, 3, 224, 224), presumably one 3-channel 224x224 image).
DG = tp.DependencyGraph()
DG.build_dependency(model, example_inputs=torch.randn(1, 3, 224, 224))

# 2. Set up the pruning strategy (L1-norm based, per its name).
strategy = tp.strategy.L1Strategy()

# 3. get a pruning plan from the dependency graph.
#    amount=0.4 presumably requests pruning 40% of conv1's indices -- confirm in tp docs.
pruning_idxs = strategy(model.conv1.weight, amount=0.4)  # or manually selected pruning_idxs=[2, 6, 9, ...]
pruning_plan = DG.get_pruning_plan(model.conv1, tp.prune_conv, idxs=pruning_idxs)

# The plan is run via pruning_plan.exec(); source code of that method:
def exec(self, dry_run=False):
    """Run every step of this pruning plan and return the summed count.

    ``self._plans`` holds ``(dep, idxs)`` pairs. Each ``dep`` is called as
    ``dep(idxs, dry_run=dry_run)`` and must return a 2-tuple whose second
    item is a count; those counts are added up.

    Args:
        dry_run: Forwarded unchanged to every ``dep`` call. Presumably it lets
            a dependency report its count without modifying the model -- the
            actual semantics live in ``dep``, which is not shown here.

    Returns:
        The sum of the counts returned by all steps (0 for an empty plan).
    """
    num_pruned = 0
    for dep, idxs in self._plans:  # idxs were computed by the chosen strategy
        _, n = dep(idxs, dry_run=dry_run)
        num_pruned += n
    return num_pruned
最后更新于