diff --git a/README.md b/README.md index 48ae9ce..9cd1b0d 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,6 @@ **车牌颜色和车牌识别一起训练看这里: [车牌识别+车牌颜色](https://github.com/we0091234/crnn_plate_recognition/tree/plate_color)** - | 模型 | 准确率 | 速度(ms) | 模型大小(MB) | link | | ------ | ------ | -------- | ------------ | ---------------------------------------------------- | | small | 96.82% | 1.2ms | 0.67 | [ezhe](https://pan.baidu.com/s/1IsQNPSRuW7bXNWc2ULfFLg) | @@ -101,9 +100,6 @@ python export.py --weights saved_model/best.pth --save_path saved_model/best.onn ``` -导出onnx文件为 saved_model/best.onnx - -如果需要onnx支持trt的话,支持[这里推理](https://github.com/we0091234/chinese_plate_tensorrt),则加上--trt #### onnx 推理 diff --git a/export.py b/export.py index 95c60af..a6fc159 100644 --- a/export.py +++ b/export.py @@ -14,7 +14,7 @@ if __name__=="__main__": parser.add_argument('--batch_size', type=int, default=1, help='batch size') parser.add_argument('--dynamic', action='store_true', default=False, help='enable dynamic axis in onnx model') parser.add_argument('--simplify', action='store_true', default=False, help='simplified onnx') - parser.add_argument('--trt', action='store_true', default=False, help='support trt') + # parser.add_argument('--trt', action='store_true', default=False, help='support trt') @@ -22,7 +22,7 @@ if __name__=="__main__": print(opt) checkpoint = torch.load(opt.weights) cfg = checkpoint['cfg'] - model = myNet_ocr(num_classes=len(plate_chr),cfg=cfg,export=True,trt=opt.trt) + model = myNet_ocr(num_classes=len(plate_chr),cfg=cfg,export=True) model.load_state_dict(checkpoint['state_dict']) model.eval() diff --git a/plateNet.py b/plateNet.py index 63faa93..134d272 100644 --- a/plateNet.py +++ b/plateNet.py @@ -3,14 +3,13 @@ import torch import torch.nn.functional as F class myNet_ocr(nn.Module): - def __init__(self,cfg=None,num_classes=78,export=False,trt=False): + def __init__(self,cfg=None,num_classes=78,export=False): super(myNet_ocr, self).__init__() if cfg is None: cfg =[32,32,64,64,'M',128,128,'M',196,196,'M',256,256] # cfg =[32,32,'M',64,64,'M',128,128,'M',256,256] self.feature = self.make_layers(cfg, True) self.export = export - self.trt= trt # self.classifier = nn.Linear(cfg[-1], num_classes) # self.loc = nn.MaxPool2d((2, 2), (5, 1), (0, 1),ceil_mode=True) # self.loc = nn.AvgPool2d((2, 2), (5, 2), (0, 1),ceil_mode=False) @@ -48,9 +47,6 @@ class myNet_ocr(nn.Module): if self.export: conv = x.squeeze(2) # b *512 * width conv = conv.transpose(2,1) # [w, b, c] - if self.trt: - conv =conv.argmax(dim=2) - conv = conv.float() return conv else: b, c, h, w = x.size()