#!/usr/bin/env python
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from argparse import ArgumentParser

import mmcv
from mmcv.utils import ProgressBar

from mmocr.apis import init_detector, model_inference
from mmocr.models import build_detector  # noqa: F401
from mmocr.utils import list_from_file, list_to_file


def gen_target_path(target_root_path, src_name, suffix):
    """Gen target file path.

    The directory part and the extension of ``src_name`` are dropped; the
    remaining base name plus ``suffix`` is joined onto ``target_root_path``.

    Args:
        target_root_path (str): The target root path.
        src_name (str): The source file name.
        suffix (str): The suffix of target file.

    Returns:
        str: The target file path.
    """
    assert isinstance(target_root_path, str)
    assert isinstance(src_name, str)
    assert isinstance(suffix, str)

    file_name = osp.split(src_name)[-1]
    name = osp.splitext(file_name)[0]
    return osp.join(target_root_path, name + suffix)


def save_results(result, out_dir, img_name, score_thr=0.3):
    """Save result of detected bounding boxes (quadrangle or polygon) to txt
    file.

    One line is written per boundary whose score is strictly greater than
    ``score_thr``; each line holds that boundary's values, rounded to
    integers and joined with commas.

    Args:
        result (dict): Text Detection result for one image. Must contain
            the key ``'boundary_result'``: a sequence of boundaries whose
            last element is compared against ``score_thr``.
        out_dir (str): Dir of txt files to save detected results.
        img_name (str): Image file name; its base name (without extension)
            names the output ``.txt`` file.
        score_thr (float, optional): Score threshold to filter bboxes.
            Must lie in the open interval (0, 1). Default: 0.3.
    """
    assert 'boundary_result' in result
    assert 0 < score_thr < 1

    txt_file = gen_target_path(out_dir, img_name, '.txt')
    # The last element of each boundary is treated as its confidence score.
    valid_boundary_res = [
        res for res in result['boundary_result'] if res[-1] > score_thr
    ]
    # NOTE(review): ``round`` is applied to every value in the row,
    # including the trailing score, so a score in (0, 1] is written as 0
    # or 1. Kept as-is to preserve the existing output format -- confirm
    # whether downstream consumers depend on it.
    lines = [
        ','.join(str(round(x)) for x in row) for row in valid_boundary_res
    ]
    list_to_file(txt_file, lines)


def main():
    """Run text detection on every image listed in ``img_list``.

    For each listed image (relative to ``img_root``), a ``.txt`` file of
    detected boundaries is written to ``<out-dir>/out_txt_dir`` and a
    visualization is written to ``<out-dir>/out_vis_dir``.

    Raises:
        FileNotFoundError: If a listed image does not exist.
    """
    parser = ArgumentParser()
    parser.add_argument('img_root', type=str, help='Image root path')
    parser.add_argument('img_list', type=str, help='Image path list file')
    parser.add_argument('config', type=str, help='Config file')
    parser.add_argument('checkpoint', type=str, help='Checkpoint file')
    parser.add_argument(
        '--score-thr', type=float, default=0.5, help='Bbox score threshold')
    parser.add_argument(
        '--out-dir',
        type=str,
        default='./results',
        help='Dir to save '
        'visualize images '
        'and bbox')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference.')
    args = parser.parse_args()
    # Validate user input with parser.error rather than assert: asserts are
    # stripped under ``python -O`` and would print no usage message.
    if not 0 < args.score_thr < 1:
        parser.error(
            f'--score-thr must be in the open interval (0, 1), '
            f'got {args.score_thr}')

    # build the model from a config file and a checkpoint file
    model = init_detector(args.config, args.checkpoint, device=args.device)
    # If the model is wrapped (exposes ``module``), use the inner model so
    # that ``show_result`` is called on it directly.
    if hasattr(model, 'module'):
        model = model.module

    # Start Inference
    out_vis_dir = osp.join(args.out_dir, 'out_vis_dir')
    mmcv.mkdir_or_exist(out_vis_dir)
    out_txt_dir = osp.join(args.out_dir, 'out_txt_dir')
    mmcv.mkdir_or_exist(out_txt_dir)

    # Skip blank entries: an empty name joins to ``img_root`` itself, which
    # passes the existence check below and would be fed to inference.
    entries = [line.strip() for line in list_from_file(args.img_list)]
    entries = [entry for entry in entries if entry]

    progressbar = ProgressBar(task_num=len(entries))
    for entry in entries:
        img_path = osp.join(args.img_root, entry)
        if not osp.exists(img_path):
            raise FileNotFoundError(img_path)
        # Test a single image
        result = model_inference(model, img_path)
        img_name = osp.basename(img_path)

        # save result
        save_results(result, out_txt_dir, img_name, score_thr=args.score_thr)

        # show result
        out_file = osp.join(out_vis_dir, img_name)
        model.show_result(
            img_path,
            result,
            score_thr=args.score_thr,
            show=False,
            out_file=out_file)
        # Advance only once the image has actually been processed.
        progressbar.update()

    print(f'\nInference done, and results saved in {args.out_dir}\n')


if __name__ == '__main__':
    main()