1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107
| """ 得到含有Image字段的数据集 author: Orch1d date: 2025-12-07 """ import json import os.path
# Directory holding the raw input files; main() joins it with the --dataset name.
data_root = './dataset'
# Directory where main() writes the filtered '<dataset>_with_image.json' file.
save_dir = './new_dataset'
import logging import colorlog import fire
def setup_logger():
    """Return the logger named 'example', configuring it on first use.

    On the first call a colorlog stream handler with a colored, timestamped
    format is attached, the level is set to DEBUG and propagation to parent
    loggers is disabled. Later calls return the already-configured logger
    unchanged, so repeated calls never stack duplicate handlers.

    Returns:
        The configured logger object obtained from ``colorlog.getLogger``.
    """
    logger = colorlog.getLogger('example')

    # Look only at handlers attached to *this* logger. ``hasHandlers()`` also
    # walks up to ancestor loggers, so a handler configured on the root logger
    # would make us return early with no colored handler, no DEBUG level and
    # propagation still enabled.
    if logger.handlers:
        return logger

    formatter = colorlog.ColoredFormatter(
        "%(log_color)s[%(asctime)s] %(levelname)-8s%(reset)s %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        reset=True,
        log_colors={
            'DEBUG': 'cyan',
            'INFO': 'green',
            'WARNING': 'yellow',
            'ERROR': 'red',
            'CRITICAL': 'red,bg_white',
        },
        secondary_log_colors={},
        style='%',
    )
    handler = colorlog.StreamHandler()
    handler.setFormatter(formatter)

    logger.addHandler(handler)
    logger.setLevel(logging.DEBUG)
    # This logger owns its output; do not let records bubble up as well.
    logger.propagate = False
    return logger
def main(**kwargs):
    """Filter a JSON Lines dataset down to the records that have an 'image' value.

    Reads ``<data_root>/<dataset>`` line by line, keeps every JSON object
    whose ``'image'`` value is not ``None``, logs summary statistics and, if
    anything matched, writes the kept records (one JSON document per line,
    UTF-8) to ``<save_dir>/<dataset>_with_image.json``.

    Args:
        **kwargs: Command-line options. Only ``dataset`` is read: the file
            name of the input dataset relative to ``data_root``.

    Returns:
        None. Problems are reported through the logger rather than raised.
    """
    log = setup_logger()
    log.info(kwargs)

    name = kwargs.get('dataset')
    if not name:
        log.error("请通过命令行参数 '--dataset' 指定数据集名称")
        return
    # python-fire converts numeric-looking values (e.g. ``--dataset 2024``)
    # to int/float; the path operations below need a string.
    name = str(name)

    input_file_path = os.path.join(data_root, name)
    log.info(f"input_file_path: {input_file_path}")
    if not os.path.exists(input_file_path):
        log.error(f"path not exists: {input_file_path}")
        return

    output_file_path = os.path.join(save_dir, name + '_with_image.json')

    try:
        total_count, records_with_image = _collect_records_with_image(
            input_file_path, log
        )
        matched_count = len(records_with_image)
        # An input with no non-blank lines must not divide by zero.
        ratio = matched_count / total_count * 100 if total_count else 0.0

        log.info("=" * 40)
        log.info("数据统计:")
        log.info(f" - 原数据总条数: {total_count}")
        log.info(f" - 包含'image'字段的条数: {matched_count}")
        log.info(f" - 筛选比例: {ratio:.2f}%")
        log.info("=" * 40)

        if records_with_image:
            # Create the destination directory on first use instead of
            # failing with FileNotFoundError.
            output_dir = os.path.dirname(output_file_path)
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)
            with open(output_file_path, 'w', encoding='utf-8') as out_file:
                out_file.writelines(
                    json.dumps(record, ensure_ascii=False) + '\n'
                    for record in records_with_image
                )
            log.info(f"saved in: {output_file_path}")
        else:
            log.warning("Image is not included in the raw data")
    except Exception as e:
        # Top-level boundary of the script: log the traceback instead of
        # letting it escape to the command-line wrapper.
        log.exception(f"Unexpected Exception: {e}")


def _collect_records_with_image(input_file_path, log):
    """Scan a JSON Lines file and collect the objects with a non-None 'image'.

    Args:
        input_file_path: Path of the file to read, one JSON document per line.
        log: Logger used to report lines that are skipped.

    Returns:
        A ``(total_count, records)`` tuple: the number of non-blank lines
        seen (including lines that failed to parse) and the list of kept
        JSON objects in file order.
    """
    total_count = 0
    records = []
    # Decode explicitly as UTF-8 (tolerating a leading BOM) so the result does
    # not depend on the platform's default encoding; undecodable bytes are
    # still dropped, as before.
    with open(input_file_path, 'r', encoding='utf-8-sig', errors='ignore') as file:
        for line_num, line in enumerate(file, 1):
            line = line.strip()
            if not line:
                continue
            total_count += 1
            try:
                js = json.loads(line)
            except json.JSONDecodeError as e:
                log.error(f"failed to parse json in line {line_num}:{e}")
                continue
            # A line can be valid JSON without being an object (list, number,
            # string, ...); such a value has no fields, so skip it rather
            # than abort the whole run on ``.get``.
            if not isinstance(js, dict):
                log.warning(f"skip non-object json in line {line_num}")
                continue
            if js.get('image') is not None:
                records.append(js)
    return total_count, records
# Script entry point: hand main() to python-fire, which is how the
# '--dataset' command-line option reaches main()'s keyword arguments.
if __name__ == '__main__': fire.Fire(main)
|