pg_backup_s3.py
· 13 KiB · Python
#!/usr/bin/env python3
"""
PostgreSQL 压缩备份到 S3 一体化脚本
支持自动生成配置模板、环境变量覆盖、错误重试、进度显示
"""
import os
import subprocess
import boto3
from botocore.exceptions import ClientError
from datetime import datetime, timedelta
import logging
import gzip
import shutil
import yaml
from dataclasses import dataclass
from typing import List
import argparse
from tqdm import tqdm
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
from boto3.s3.transfer import TransferConfig
# -------------------------
# Config dataclass and defaults
# -------------------------
@dataclass
class BackupConfig:
db_name: str = "your_database_name"
db_user: str = "your_db_user"
db_password: str = "your_db_password"
s3_endpoint: str = "https://s3.your-provider.com"
s3_access_key: str = "your_access_key"
s3_secret_key: str = "your_secret_key"
s3_bucket: str = "your-bucket-name"
backup_dir: str = "/var/lib/pg_backup"
compress_level: int = 6
keep_days: int = 7
log_file: str = "/var/log/pg_backup.log"
pg_host: str = "localhost"
pg_port: int = 5432
use_ssl: bool = False
# -------------------------
# Config file management
# -------------------------
def create_env_template(env_path: str) -> None:
    """Generate a .env template file for environment variables."""
    template = """# PostgreSQL-to-S3 backup: environment variables
# Keep secrets here; do not commit this file to version control.

PG_DB_PASSWORD=your_db_password  # database password
S3_ACCESS_KEY=your_access_key    # S3 access key
S3_SECRET_KEY=your_secret_key    # S3 secret key
"""
    with open(env_path, "w") as f:
        f.write(template)
    print(f"Wrote .env template to {env_path}; fill in the secrets")
    # Restrict permissions to owner read/write on non-Windows systems
    if os.name != 'nt':
        os.chmod(env_path, 0o600)
def create_config_template(config_path: str) -> None:
    """Generate a YAML config template file."""
    default_config = BackupConfig()
    template = f"""# PostgreSQL-to-S3 backup configuration
# Prefer setting secrets via environment variables (they override file values):
# PG_DB_PASSWORD, S3_ACCESS_KEY, S3_SECRET_KEY

db_name: {default_config.db_name}  # database name
db_user: {default_config.db_user}  # database user
# db_password: {default_config.db_password}  # database password (prefer the environment variable)
s3_endpoint: {default_config.s3_endpoint}  # S3 endpoint (e.g. https://s3.example.com)
s3_access_key: {default_config.s3_access_key}  # S3 access key
# s3_secret_key: {default_config.s3_secret_key}  # S3 secret key (prefer the environment variable)
s3_bucket: {default_config.s3_bucket}  # S3 bucket name
backup_dir: {default_config.backup_dir}  # local backup directory (must be writable)
keep_days: {default_config.keep_days}  # retention in days (older backups are deleted)
pg_host: {default_config.pg_host}  # database host (default: localhost)
pg_port: {default_config.pg_port}  # database port (default: 5432)
use_ssl: {default_config.use_ssl}  # require SSL for the connection (default: false)
log_file: {default_config.log_file}  # log file path
compress_level: {default_config.compress_level}  # gzip level 0-9 (default: 6)
"""
    with open(config_path, "w") as f:
        f.write(template)
    print(f"Wrote config template to {config_path}; edit it and run again")

    # Also generate a .env template next to the config file
    env_path = os.path.join(os.path.dirname(config_path), ".env")
    if not os.path.exists(env_path):
        create_env_template(env_path)
def load_or_create_config(config_path: str) -> BackupConfig:
    """Load the config file; generate a template and exit if it is missing."""
    if not os.path.exists(config_path):
        create_config_template(config_path)
        raise SystemExit(0)

    # Load the .env file if present
    env_path = os.path.join(os.path.dirname(config_path), ".env")
    if os.path.exists(env_path):
        from dotenv import load_dotenv
        load_dotenv(env_path)

    with open(config_path, "r") as f:
        cfg = yaml.safe_load(f)

    # Environment variables override secrets from the file
    env_override = {
        "db_password": os.getenv("PG_DB_PASSWORD"),
        "s3_access_key": os.getenv("S3_ACCESS_KEY"),
        "s3_secret_key": os.getenv("S3_SECRET_KEY")
    }
    for key, value in env_override.items():
        if value:
            cfg[key] = value

    return BackupConfig(**cfg)
# -------------------------
# Logging setup
# -------------------------
def setup_logger(log_file: str) -> logging.Logger:
logger = logging.getLogger("PGBackup")
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    # Console and file log handlers
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(formatter)
file_handler = logging.FileHandler(log_file)
file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(formatter)
logger.addHandler(console_handler)
logger.addHandler(file_handler)
return logger
# -------------------------
# Core functionality
# -------------------------
class BackupManager:
def __init__(self, config: BackupConfig):
self.config = config
self.logger = setup_logger(config.log_file)
self.s3_client = self._init_s3_client()
def _init_s3_client(self):
return boto3.client(
's3',
endpoint_url=self.config.s3_endpoint,
aws_access_key_id=self.config.s3_access_key,
aws_secret_access_key=self.config.s3_secret_key,
            region_name='cn-sy1',  # provider-specific region; adjust to match your endpoint
            config=boto3.session.Config(signature_version='s3v4')
)
def check_prerequisites(self) -> bool:
try:
os.makedirs(self.config.backup_dir, exist_ok=True)
            # Verify the backup directory is writable
test_file = os.path.join(self.config.backup_dir, ".test")
with open(test_file, "w") as f:
f.write("test")
os.remove(test_file)
subprocess.run(["pg_dump", "--version"], check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
return True
except Exception as e:
self.logger.error(f"前置检查失败: {str(e)}")
return False
def create_compressed_backup(self) -> str:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
base_name = f"{self.config.db_name}_backup_{timestamp}"
sql_path = os.path.join(self.config.backup_dir, base_name + ".sql")
gz_path = sql_path + ".gz"
        try:
            # Ensure the required connection parameters are set
            # (note: str(pg_port) would always be truthy, so check the raw value)
            if not all([self.config.db_name, self.config.db_user, self.config.db_password,
                        self.config.pg_host, self.config.pg_port]):
                raise ValueError("Incomplete database connection parameters")

            env = os.environ.copy()
            env["PGPASSWORD"] = self.config.db_password
pg_cmd = [
"pg_dump",
f"--dbname={self.config.db_name}",
f"--host={self.config.pg_host}",
f"--port={self.config.pg_port}",
f"--username={self.config.db_user}",
f"--file={sql_path}"
]
            if self.config.use_ssl:
                # pg_dump has no --ssl-mode flag; request SSL via libpq's PGSSLMODE variable
                env["PGSSLMODE"] = "require"

            self.logger.debug(f"Running pg_dump: {' '.join(pg_cmd)}")
            print(f"[*] Starting database backup: {self.config.db_name}")

            # Run pg_dump; its output is fully captured, so no byte-level progress is shown
            result = subprocess.run(
                pg_cmd,
                env=env,
                capture_output=True,
                text=True
            )

            if result.returncode != 0:
                error_msg = result.stderr.strip()
                self.logger.error(f"pg_dump stderr: {error_msg}")
                raise RuntimeError(f"pg_dump failed: {error_msg}")
            # Compress with chunked reads so the progress bar advances incrementally
            print(f"[*] Compressing backup file: {os.path.basename(gz_path)}")
            file_size = os.path.getsize(sql_path)
            with tqdm(total=file_size, desc="Compressing", unit="B", unit_scale=True) as pbar:
                with open(sql_path, "rb") as f_in, gzip.open(gz_path, "wb", self.config.compress_level) as f_out:
                    while True:
                        chunk = f_in.read(1024 * 1024)  # 1 MiB chunks
                        if not chunk:
                            break
                        f_out.write(chunk)
                        pbar.update(len(chunk))

            os.remove(sql_path)
            print(f"[✓] Backup file created: {os.path.basename(gz_path)}")
            return gz_path
        except Exception as e:
            self._cleanup_temp_files([sql_path, gz_path])
            self.logger.error(f"Backup creation failed: {e}")
            raise
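    # tenacity retry: up to 3 attempts on S3 ClientError, with exponential backoff of 2-10 s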
@retry(stop=stop_after_attempt(3), wait=wait_exponential(min=2, max=10), retry=retry_if_exception_type(ClientError))
    def upload_to_s3(self, file_path: str) -> None:
        file_name = os.path.basename(file_path)
        print(f"[*] Uploading to S3: {file_name}")

        transfer_config = TransferConfig(multipart_threshold=10*1024**2, max_concurrency=10)

        with open(file_path, "rb") as f, tqdm(
            total=os.path.getsize(file_path),
            unit="B", unit_scale=True,
            desc="Uploading",
            leave=True  # keep the finished bar on screen
        ) as pbar:

            def progress(bytes_transferred):
                pbar.update(bytes_transferred)

            self.s3_client.upload_fileobj(f, self.config.s3_bucket, file_name, Config=transfer_config, Callback=progress)

        print(f"[✓] Upload complete: {file_name}")
        print(f"    S3 location: {self.config.s3_bucket}/{file_name}")
def clean_old_backups(self) -> None:
cutoff = datetime.now() - timedelta(days=self.config.keep_days)
for file in os.listdir(self.config.backup_dir):
if file.endswith(".gz"):
path = os.path.join(self.config.backup_dir, file)
mtime = datetime.fromtimestamp(os.path.getmtime(path))
if mtime < cutoff:
try:
os.remove(path)
self.logger.info(f"删除过期备份: {file}")
except Exception as e:
self.logger.warning(f"删除失败: {file} - {str(e)}")
@staticmethod
def _cleanup_temp_files(files: List[str]) -> None:
for f in files:
if os.path.exists(f):
try:
os.remove(f)
                except OSError:
pass
# -------------------------
# Command-line interface
# -------------------------
def parse_arguments() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="PostgreSQL-to-S3 backup tool")
    parser.add_argument("-c", "--config", default="pg_backup.yaml", help="config file path (default: pg_backup.yaml)")
return parser.parse_args()
# -------------------------
# Main flow
# -------------------------
def check_dependencies():
    """Check that all third-party dependencies are installed."""
    # Note: boto3, yaml, tqdm and tenacity are imported at module level, so a
    # missing one fails before this runs; in practice this mainly catches dotenv.
    try:
        import boto3
        import yaml
        import tqdm
        import tenacity
        import dotenv
    except ImportError as e:
        print(f"[!] Missing dependency: {e}")
        print("Run: pip install -r requirements.txt")
        raise SystemExit(1)

# Called at the start of main()
def main():
    check_dependencies()
    args = parse_arguments()
    config_path = args.config

    manager = None
    try:
        config = load_or_create_config(config_path)
        manager = BackupManager(config)

        print("[*] Starting backup run")
        if not manager.check_prerequisites():
            raise SystemExit("[!] Prerequisite check failed, aborting")

        backup_path = manager.create_compressed_backup()
        manager.upload_to_s3(backup_path)
        manager.clean_old_backups()
        os.remove(backup_path)  # remove the local copy after upload

        print("\n[✓] Backup run completed successfully")

    except SystemExit:
        raise
    except Exception as e:
        # manager may not exist yet if config loading or client setup failed
        if manager is not None:
            manager.logger.error(f"Backup failed: {e}", exc_info=True)
        print(f"\n[!] Backup failed: {e}")
        raise SystemExit(1)
if __name__ == "__main__":
main()
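Note: a small, hypothetical follow-up sketch (not part of the script above). One way to spot-check an upload out of band is to compare the S3 object's ContentLength with the local file size; the client is assumed to be configured as in _init_s3_client, and verify_upload is an illustrative name, not something the script defines.

import os

def verify_upload(s3_client, bucket: str, key: str, local_path: str) -> bool:
    # head_object returns object metadata without downloading the body
    meta = s3_client.head_object(Bucket=bucket, Key=key)
    return meta["ContentLength"] == os.path.getsize(local_path)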
pg_restore_s3.py
· 9.8 KiB · Python
#!/usr/bin/env python3
"""
PostgreSQL restore script (optimized)
Restores databases that pg_backup_s3.py backed up to S3.
Supports config-file settings, retry on error, progress display, and detailed logging.
"""
import os
import boto3
import gzip
import subprocess
import shutil
import logging
import yaml
from botocore.exceptions import ClientError
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
from tqdm import tqdm
import argparse
# Default config file path
DEFAULT_CONFIG_PATH = "pg_backup.yaml"
# Logging setup
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
handlers=[
logging.StreamHandler(),
logging.FileHandler('/var/log/pg_restore.log')
]
)
logger = logging.getLogger('PG_Restore')
def print_step(message):
print(f"→ {message}")
def load_config(config_path):
    """
    Load settings from the YAML config file.
    """
    if not os.path.exists(config_path):
        logger.error(f"Config file {config_path} does not exist; check the path.")
        raise FileNotFoundError(f"Config file {config_path} does not exist.")
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)
    return config
def get_s3_client(config):
    """
    Create an S3 client.
    """
    return boto3.client(
        's3',
        endpoint_url=config['s3_endpoint'],
        aws_access_key_id=config['s3_access_key'],
        aws_secret_access_key=config['s3_secret_key'],
        region_name='cn-sy1',  # provider-specific region; keep in sync with pg_backup_s3.py
        config=boto3.session.Config(signature_version='s3v4')
    )
@retry(stop=stop_after_attempt(3), wait=wait_exponential(min=2, max=10), retry=retry_if_exception_type(ClientError))
def list_backup_files(config):
    """
    List backup files in S3, newest first.
    """
    try:
        s3 = get_s3_client(config)
        # Note: list_objects_v2 returns at most 1000 keys per call
        response = s3.list_objects_v2(Bucket=config['s3_bucket'])
        if 'Contents' not in response:
            print_step("No backup files found in the S3 bucket")
            return []
        files = [obj['Key'] for obj in response['Contents'] if obj['Key'].endswith('.gz')]
        files.sort(reverse=True)  # timestamped names sort newest first
        if not files:
            print_step("No .gz backup files found")
            return []
        return files
    except Exception as e:
        logger.error(f"Failed to list backups: {e}")
        raise
class DownloadProgressPercentage:
    """
    Download progress display.
    """
    def __init__(self, filename, total_size):
        self._filename = filename
        self._size = total_size
        self._seen_so_far = 0
        self._pbar = tqdm(total=total_size, unit='B', unit_scale=True, desc=f"Downloading {filename}", leave=False)

    def __call__(self, bytes_amount):
        self._seen_so_far += bytes_amount
        self._pbar.update(bytes_amount)
@retry(stop=stop_after_attempt(3), wait=wait_exponential(min=2, max=10), retry=retry_if_exception_type(ClientError))
def download_from_s3(config, file_name):
    """
    Download a backup file from S3.
    """
    try:
        restore_dir = config.get('restore_dir', '/tmp/pg_restores')
        os.makedirs(restore_dir, exist_ok=True)
        local_path = os.path.join(restore_dir, file_name)
        s3 = get_s3_client(config)
        print_step(f"Downloading {file_name}...")
        # Fetch the object size for the progress bar
        file_size = s3.head_object(Bucket=config['s3_bucket'], Key=file_name)['ContentLength']
        s3.download_file(
            Bucket=config['s3_bucket'],
            Key=file_name,
            Filename=local_path,
            Callback=DownloadProgressPercentage(file_name, file_size)
        )
        print()  # newline after the progress bar
        return local_path
    except Exception as e:
        logger.error(f"Failed to download backup: {e}")
        raise
def decompress_file(compressed_path):
    """
    Decompress a backup file.
    """
    try:
        print_step("Decompressing backup file...")
        decompressed_path = compressed_path[:-3]  # strip the .gz suffix
        with gzip.open(compressed_path, 'rb') as f_in:
            with open(decompressed_path, 'wb') as f_out:
                shutil.copyfileobj(f_in, f_out)
        return decompressed_path
    except Exception as e:
        logger.error(f"Failed to decompress backup: {e}")
        raise
def restore_database(config, sql_file):
    """
    Run the database restore.
    """
    try:
        # Ask the user for a restore mode
        print("\nChoose a restore mode:")
        print("1. Full restore (drop the database, then restore)")
        print("2. Append restore (keep existing data, apply the backup on top)")
        while True:
            try:
                mode = int(input("Enter 1 or 2: "))
                if mode in [1, 2]:
                    break
                print("Invalid choice, enter 1 or 2")
            except ValueError:
                print("Please enter a number")

        env = os.environ.copy()
        env['PGPASSWORD'] = config['db_password']

        # Full restore: build a temp database, then swap it into place
        if mode == 1:
            print_step("Preparing full restore...")
            temp_db = f"{config['db_name']}_temp"

            # 0. Drop any leftover temp database from a previous run
            print_step("Cleaning up any existing temp database...")
            drop_temp_cmd = [
                'sudo', '-u', 'postgres',
                'psql',
                '-c', f"DROP DATABASE IF EXISTS {temp_db};"
            ]
            subprocess.run(drop_temp_cmd, check=True)

            # 1. Create the temp database
            print_step("Creating temp database...")
            create_temp_cmd = [
                'sudo', '-u', 'postgres',
                'psql',
                '-c', f"CREATE DATABASE {temp_db} WITH OWNER {config['db_user']} ENCODING 'UTF8';"
            ]
            subprocess.run(create_temp_cmd, check=True)

            # 2. Restore the dump into the temp database
            print_step("Restoring data into the temp database...")
            restore_temp_cmd = [
                'psql',
                '-U', config['db_user'],
                '-h', 'localhost',
                '-d', temp_db,
                '-f', sql_file
            ]
            subprocess.run(restore_temp_cmd, env=env, check=True)

            # 3. Terminate all sessions connected to the original database
            print_step("Terminating connections to the original database...")
            terminate_cmd = [
                'sudo', '-u', 'postgres',
                'psql',
                '-c', f"SELECT pg_terminate_backend(pg_stat_activity.pid) FROM pg_stat_activity WHERE pg_stat_activity.datname = '{config['db_name']}';"
            ]
            subprocess.run(terminate_cmd, check=True)

            # 4. Drop the original database
            print_step("Dropping the original database...")
            drop_orig_cmd = [
                'sudo', '-u', 'postgres',
                'psql',
                '-c', f"DROP DATABASE IF EXISTS {config['db_name']};"
            ]
            subprocess.run(drop_orig_cmd, check=True)

            # 5. Rename the temp database into place
            print_step("Finishing the restore...")
            rename_cmd = [
                'sudo', '-u', 'postgres',
                'psql',
                '-c', f"ALTER DATABASE {temp_db} RENAME TO {config['db_name']};"
            ]
            subprocess.run(rename_cmd, check=True)

            print_step("Database restored successfully")
            return  # do not fall through to the append restore below

        # Append restore (mode 2)
        print_step("Restoring database...")
        restore_cmd = [
            'psql',
            '-U', config['db_user'],
            '-h', 'localhost',
            '-d', config['db_name'],
            '-f', sql_file
        ]
        result = subprocess.run(
            restore_cmd,
            env=env,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True
        )

        if result.returncode != 0:
            raise Exception(f"Restore failed: {result.stderr.strip()}")

        print_step("Database restored successfully")

    except Exception as e:
        logger.error(f"Database restore failed: {e}")
        raise
def cleanup(file_path):
    """
    Remove a temporary file.
    """
    try:
        if os.path.exists(file_path):
            os.remove(file_path)
    except Exception as e:
        logger.warning(f"Failed to remove file: {e}")
def main():
    parser = argparse.ArgumentParser(description="PostgreSQL restore tool")
    parser.add_argument("-c", "--config", default=DEFAULT_CONFIG_PATH, help="config file path")
    args = parser.parse_args()

    print("\n" + "=" * 50)
    print("PostgreSQL restore tool")
    print("=" * 50 + "\n")

    try:
        config = load_config(args.config)
        # List available backups
        backup_files = list_backup_files(config)
        if not backup_files:
            return

        # Show the backup list
        print("\nAvailable backups:")
        for i, file in enumerate(backup_files, 1):
            print(f"{i}. {file}")

        # Choose a backup to restore
        while True:
            try:
                choice = int(input("\nEnter the number of the backup to restore: "))
                if 1 <= choice <= len(backup_files):
                    selected_file = backup_files[choice - 1]
                    break
                print("Invalid choice, try again")
            except ValueError:
                print("Please enter a number")

        # Download and restore
        compressed_path = download_from_s3(config, selected_file)
        sql_path = decompress_file(compressed_path)
        restore_database(config, sql_path)

    except Exception as e:
        print_step(f"[Error] {e}")
    finally:
        # Clean up temporary files
        if 'compressed_path' in locals():
            cleanup(compressed_path)
        if 'sql_path' in locals():
            cleanup(sql_path)
        print("\n[Done]")
if __name__ == "__main__":
main()
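Note: the interactive prompt above can be skipped for unattended restores. A minimal sketch, assuming list_backup_files keeps its reverse lexicographic sort (the timestamped names then come back newest first); pick_latest is an illustrative helper, not part of the script.

def pick_latest(backup_files):
    # Keys like mydb_backup_20240131_120000.sql.gz sort newest-first under reverse=True
    if not backup_files:
        raise SystemExit("No backups available")
    return backup_files[0]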
requirements.txt
· 109 B · Text
boto3>=1.26.0,<2.0.0
pyyaml>=6.0,<7.0
tenacity>=8.2.2,<9.0
tqdm>=4.65.0,<5.0
python-dotenv>=0.19.0,<1.0.0
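
Note: a quick sketch for confirming these pins resolved in the active environment; the names are the PyPI distribution names from this file, and the snippet assumes Python 3.8+ for importlib.metadata.

from importlib.metadata import version, PackageNotFoundError

for dist in ("boto3", "pyyaml", "tenacity", "tqdm", "python-dotenv"):
    try:
        print(f"{dist}=={version(dist)}")
    except PackageNotFoundError:
        print(f"{dist}: not installed")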