#!/usr/bin/env python3
"""
PostgreSQL restore script (optimized edition).

Restores databases that pg_backup_s3.py backed up to S3: lists the ``.gz``
backups in the configured bucket, downloads and decompresses the chosen one,
and replays it with ``psql``. Settings come from a YAML config file; S3 calls
are retried, downloads show a progress bar, and activity is logged.

Config keys read by this script:
    s3_endpoint, s3_access_key, s3_secret_key, s3_bucket   (required)
    s3_region      optional, defaults to 'cn-sy1'
    db_name, db_user, db_password                          (required)
    restore_dir    optional, defaults to '/tmp/pg_restores'
    on_error_stop  optional bool, defaults to False; when true, psql runs with
                   ``-v ON_ERROR_STOP=1`` so the first SQL error aborts the
                   restore (recommended for full restores, which drop the
                   original database afterwards)
"""
import argparse
import gzip
import logging
import os
import shutil
import subprocess
import sys
import threading
from datetime import datetime

import boto3
import yaml
from botocore.exceptions import ClientError
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
from tqdm import tqdm

# Default config file location (resolved relative to the working directory).
DEFAULT_CONFIG_PATH = "pg_backup.yaml"

# Log file location; writing under /var/log usually requires root.
LOG_FILE_PATH = '/var/log/pg_restore.log'


def _build_log_handlers():
    """Return a console handler plus a file handler for LOG_FILE_PATH when it can be opened.

    Opening the log file fails (e.g. PermissionError) when the script runs
    without write access to /var/log; fall back to console-only logging rather
    than aborting at import time.
    """
    handlers = [logging.StreamHandler()]
    try:
        handlers.append(logging.FileHandler(LOG_FILE_PATH))
    except OSError as exc:
        print(f"警告: 无法打开日志文件 {LOG_FILE_PATH} ({exc}),仅输出到控制台", file=sys.stderr)
    return handlers


# Logging setup
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    handlers=_build_log_handlers(),
)
logger = logging.getLogger('PG_Restore')


def print_step(message):
    """Print a progress line prefixed with an arrow."""
    print(f"→ {message}")


def load_config(config_path):
    """Load the YAML config file and return it as a dict.

    Raises:
        FileNotFoundError: if *config_path* does not exist.
        ValueError: if the file is empty or its top level is not a mapping.
    """
    if not os.path.exists(config_path):
        logger.error(f"配置文件 {config_path} 不存在,请检查路径。")
        raise FileNotFoundError(f"配置文件 {config_path} 不存在。")
    # Read bytes so PyYAML detects the encoding itself instead of relying on the locale.
    with open(config_path, 'rb') as f:
        config = yaml.safe_load(f)
    if not isinstance(config, dict):
        # An empty file loads as None; every later lookup would fail with a confusing TypeError.
        raise ValueError(f"配置文件 {config_path} 格式无效: 顶层必须是键值映射")
    return config


def get_s3_client(config):
    """Create an S3 client from the endpoint/credential keys in *config*."""
    return boto3.client(
        's3',
        endpoint_url=config['s3_endpoint'],
        aws_access_key_id=config['s3_access_key'],
        aws_secret_access_key=config['s3_secret_key'],
        region_name=config.get('s3_region', 'cn-sy1'),
        config=boto3.session.Config(signature_version='s3v4'),
    )


@retry(stop=stop_after_attempt(3), wait=wait_exponential(min=2, max=10),
       retry=retry_if_exception_type(ClientError), reraise=True)
def list_backup_files(config):
    """Return the ``.gz`` keys in the bucket, sorted by key name descending.

    The descending key order is treated as newest-first, which assumes the
    backup names embed a sortable timestamp. All result pages are read, so
    buckets with more than 1000 objects are fully listed. ClientError is retried
    up to 3 times and the last one is re-raised unchanged.
    """
    try:
        s3 = get_s3_client(config)
        paginator = s3.get_paginator('list_objects_v2')
        total_objects = 0
        files = []
        for page in paginator.paginate(Bucket=config['s3_bucket']):
            contents = page.get('Contents', [])
            total_objects += len(contents)
            files.extend(obj['Key'] for obj in contents if obj['Key'].endswith('.gz'))

        if total_objects == 0:
            print_step("S3 桶中没有找到备份文件")
            return []

        files.sort(reverse=True)  # newest first, assuming timestamped key names
        if not files:
            print_step("没有找到 .gz 格式的备份文件")
            return []
        return files
    except Exception as e:
        logger.error(f"获取备份列表失败: {str(e)}")
        raise


class DownloadProgressPercentage:
    """Progress callback for ``s3.download_file`` that drives a tqdm progress bar.

    The transfer manager may invoke the callback from worker threads, so
    updates are serialized with a lock. Call :meth:`close` once the transfer
    ends to release the bar.
    """

    def __init__(self, filename, total_size):
        self._filename = filename
        self._size = total_size
        self._seen_so_far = 0
        self._lock = threading.Lock()
        self._pbar = tqdm(total=total_size, unit='B', unit_scale=True,
                          desc=f"下载 {filename}", leave=False)

    def __call__(self, bytes_amount):
        """Record *bytes_amount* newly transferred bytes."""
        with self._lock:
            self._seen_so_far += bytes_amount
            self._pbar.update(bytes_amount)

    def close(self):
        """Close the underlying progress bar."""
        self._pbar.close()


@retry(stop=stop_after_attempt(3), wait=wait_exponential(min=2, max=10),
       retry=retry_if_exception_type(ClientError), reraise=True)
def download_from_s3(config, file_name):
    """Download S3 object *file_name* into ``restore_dir`` and return the local path.

    Only the last path component of the key is used locally. A key such as
    ``backups/db.sql.gz`` lands directly in ``restore_dir``, and ``..``
    segments cannot escape it. ClientError is retried up to 3 times.
    """
    try:
        restore_dir = config.get('restore_dir', '/tmp/pg_restores')
        os.makedirs(restore_dir, exist_ok=True)
        local_path = os.path.join(restore_dir, os.path.basename(file_name))

        s3 = get_s3_client(config)
        print_step(f"正在下载 {file_name}...")

        # Object size drives the progress bar total.
        file_size = s3.head_object(Bucket=config['s3_bucket'], Key=file_name)['ContentLength']
        progress = DownloadProgressPercentage(file_name, file_size)
        try:
            s3.download_file(
                Bucket=config['s3_bucket'],
                Key=file_name,
                Filename=local_path,
                Callback=progress,
            )
        finally:
            progress.close()
        print()  # move past the progress bar line
        return local_path
    except Exception as e:
        logger.error(f"下载备份文件失败: {str(e)}")
        raise


def decompress_file(compressed_path):
    """Gunzip *compressed_path* next to itself (dropping ``.gz``) and return the new path.

    On failure, any partially written output file is removed before re-raising.
    The caller never receives that path, so it could not clean it up otherwise.
    """
    partial_path = None
    try:
        print_step("正在解压备份文件...")
        if not compressed_path.endswith('.gz'):
            raise ValueError(f"不是 .gz 文件: {compressed_path}")
        decompressed_path = compressed_path[:-len('.gz')]
        with gzip.open(compressed_path, 'rb') as f_in:
            with open(decompressed_path, 'wb') as f_out:
                partial_path = decompressed_path
                shutil.copyfileobj(f_in, f_out)
        return decompressed_path
    except Exception as e:
        logger.error(f"解压备份文件失败: {str(e)}")
        if partial_path is not None:
            cleanup(partial_path)
        raise


def _quote_ident(name):
    """Quote *name* as a PostgreSQL identifier (double quotes, embedded quotes doubled)."""
    return '"' + str(name).replace('"', '""') + '"'


def _quote_literal(value):
    """Quote *value* as a PostgreSQL string literal (single quotes, embedded quotes doubled)."""
    return "'" + str(value).replace("'", "''") + "'"


def _run_as_postgres(sql):
    """Run one SQL statement via ``sudo -u postgres psql -c``; raise CalledProcessError on failure."""
    subprocess.run(['sudo', '-u', 'postgres', 'psql', '-c', sql], check=True)


def _psql_file_cmd(config, database, sql_file):
    """Build the psql command that replays *sql_file* into *database* as ``db_user``."""
    cmd = ['psql', '-U', config['db_user'], '-h', 'localhost', '-d', database]
    if config.get('on_error_stop', False):
        # Without this psql keeps going after SQL errors and still exits 0.
        cmd += ['-v', 'ON_ERROR_STOP=1']
    cmd += ['-f', sql_file]
    return cmd


def _prompt_restore_mode():
    """Ask the user for the restore mode and return 1 (full) or 2 (append)."""
    print("\n请选择恢复模式:")
    print("1. 完全恢复 (先清空数据库,再恢复)")
    print("2. 追加恢复 (保留现有数据,只添加备份数据)")
    while True:
        try:
            mode = int(input("请输入选择(1或2): "))
        except ValueError:
            print("请输入有效的数字")
            continue
        if mode in (1, 2):
            return mode
        print("输入无效,请输入 1 或 2")


def _full_restore(config, sql_file, env):
    """Restore into ``<db_name>_temp``, then replace the original database with it.

    The original database is dropped only after the temp restore command
    succeeds. Unless ``on_error_stop`` is enabled, psql can report success even
    when individual statements failed.
    """
    db_name = config['db_name']
    temp_db = f"{db_name}_temp"
    print_step("正在准备完全恢复...")

    # 0. Drop a leftover temp database from an earlier run.
    print_step("正在清理可能存在的临时数据库...")
    _run_as_postgres(f"DROP DATABASE IF EXISTS {_quote_ident(temp_db)};")

    # 1. Create the temp database owned by the restore user.
    print_step("正在创建临时数据库...")
    _run_as_postgres(
        f"CREATE DATABASE {_quote_ident(temp_db)} "
        f"WITH OWNER {_quote_ident(config['db_user'])} ENCODING 'UTF8';"
    )

    # 2. Replay the backup into the temp database.
    print_step("正在恢复数据到临时数据库...")
    subprocess.run(_psql_file_cmd(config, temp_db, sql_file), env=env, check=True)

    # 3. Disconnect sessions on the original database (but not psql's own session).
    print_step("正在终止原数据库连接...")
    _run_as_postgres(
        "SELECT pg_terminate_backend(pg_stat_activity.pid) FROM pg_stat_activity "
        f"WHERE pg_stat_activity.datname = {_quote_literal(db_name)} "
        "AND pg_stat_activity.pid <> pg_backend_pid();"
    )

    # 4. Drop the original database.
    print_step("正在清理原数据库...")
    _run_as_postgres(f"DROP DATABASE IF EXISTS {_quote_ident(db_name)};")

    # 5. Rename the temp database into place.
    print_step("正在完成恢复...")
    try:
        _run_as_postgres(f"ALTER DATABASE {_quote_ident(temp_db)} RENAME TO {_quote_ident(db_name)};")
    except subprocess.CalledProcessError:
        # The original is already gone at this point; tell the operator where the data is.
        logger.error(f"原数据库已删除,但重命名失败;恢复的数据保留在临时数据库 {temp_db} 中")
        raise


def _append_restore(config, sql_file, env):
    """Replay *sql_file* into the existing database, keeping its current data."""
    print_step("正在恢复数据库...")
    result = subprocess.run(
        _psql_file_cmd(config, config['db_name'], sql_file),
        env=env,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        text=True,
    )
    if result.returncode != 0:
        raise Exception(f"恢复失败: {result.stderr.strip()}")


def restore_database(config, sql_file):
    """Interactively choose a restore mode and restore *sql_file* into ``db_name``.

    Mode 1 replaces the database via a temp copy; mode 2 appends into the
    existing database. The backup is replayed exactly once in either mode.
    Errors are logged and re-raised.
    """
    try:
        mode = _prompt_restore_mode()

        env = os.environ.copy()
        # str(): YAML parses numeric-looking passwords as int, which os.environ rejects.
        env['PGPASSWORD'] = str(config['db_password'])

        if mode == 1:
            _full_restore(config, sql_file, env)
        else:
            _append_restore(config, sql_file, env)

        print_step("数据库恢复成功")
    except Exception as e:
        logger.error(f"数据库恢复失败: {str(e)}")
        raise


def cleanup(file_path):
    """Best-effort removal of a temporary file; failures are logged, not raised."""
    try:
        if os.path.exists(file_path):
            os.remove(file_path)
    except Exception as e:
        logger.warning(f"清理文件失败: {str(e)}")


def _prompt_backup_choice(backup_files):
    """Ask for a 1-based backup number and return the matching entry of *backup_files*."""
    while True:
        try:
            choice = int(input("\n请输入要恢复的备份编号: "))
        except ValueError:
            print("请输入有效的数字")
            continue
        if 1 <= choice <= len(backup_files):
            return backup_files[choice - 1]
        print("输入无效,请重新输入")


def main():
    """Entry point: pick a backup, download, decompress and restore it.

    Returns the process exit status: 0 on success (or when no backups exist),
    1 when any step failed. Downloaded and decompressed files are always
    removed afterwards.
    """
    parser = argparse.ArgumentParser(description="PostgreSQL 恢复脚本")
    parser.add_argument("-c", "--config", default=DEFAULT_CONFIG_PATH, help="配置文件路径")
    args = parser.parse_args()

    print("\n" + "=" * 50)
    print("PostgreSQL 恢复脚本")
    print("=" * 50 + "\n")

    compressed_path = None
    sql_path = None
    exit_code = 0
    try:
        config = load_config(args.config)

        backup_files = list_backup_files(config)
        if not backup_files:
            return 0

        print("\n可用的备份文件:")
        for i, file in enumerate(backup_files, 1):
            print(f"{i}. {file}")

        selected_file = _prompt_backup_choice(backup_files)

        compressed_path = download_from_s3(config, selected_file)
        sql_path = decompress_file(compressed_path)
        restore_database(config, sql_path)
    except Exception as e:
        print_step(f"[错误] {str(e)}")
        exit_code = 1
    finally:
        # Remove temporary files whichever step failed.
        if compressed_path is not None:
            cleanup(compressed_path)
        if sql_path is not None:
            cleanup(sql_path)
        print("\n[操作完成]")
    return exit_code


if __name__ == "__main__":
    sys.exit(main())