pg_backup_s3.py

#!/usr/bin/env python3
# Recommended to run via a scheduler of your choice (e.g. cron).
# This script backs up a PostgreSQL database to S3-compatible storage.
# Edit the configuration below before use (install the Python environment
# and dependencies yourself).
# What it does:
# 1. Read configuration
# 2. Check prerequisites
# 3. Create a compressed backup
# 4. Upload the backup file to S3
# 5. Clean up old backups
# 6. Logging
# 7. Error handling
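#
# A minimal scheduling sketch (the crontab entry, script path, and log path
# are assumptions, not part of the original) -- runs the backup daily at 03:00:
#   0 3 * * * /usr/bin/python3 /opt/scripts/pg_backup_s3.py >> /var/log/pg_backup_cron.log 2>&1
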
import os
import subprocess
import boto3
from botocore.exceptions import ClientError
from datetime import datetime
import logging
import gzip
import shutil
from boto3.s3.transfer import TransferConfig

# Configuration -- edit these values for your environment
DB_NAME = 'database_name'
DB_USER = 'database_user'
DB_PASSWORD = 'database_password'
S3_ENDPOINT = 'your_s3_endpoint'
S3_ACCESS_KEY = 'your_s3_access_key'
S3_SECRET_KEY = 'your_s3_secret_key'
S3_BUCKET = 'your_s3_bucket_name'
BACKUP_DIR = '/tmp/pg_backups'  # Directory for local backup files
COMPRESS_LEVEL = 6  # gzip level (0-9): 0 = no compression, 9 = maximum; leave as-is if unsure
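# Note: hardcoding credentials is risky if this file is shared or committed.
# A sketch of an alternative, assuming environment variable names of your
# own choosing:
#   DB_PASSWORD = os.environ.get('PG_BACKUP_DB_PASSWORD', '')
#   S3_SECRET_KEY = os.environ.get('PG_BACKUP_S3_SECRET_KEY', '')
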
# Logging setup
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    handlers=[
        logging.StreamHandler(),
        # Writing to /var/log usually requires root; adjust the path if needed
        logging.FileHandler('/var/log/pg_backup_compressed.log')
    ]
)
logger = logging.getLogger('PG_Backup_Compressed')

def print_step(message):
    print(f"→ {message}")

def check_prerequisites():
    """Check prerequisites: backup dir is writable and pg_dump is available."""
    try:
        os.makedirs(BACKUP_DIR, exist_ok=True)
        test_file = os.path.join(BACKUP_DIR, '.test')
        with open(test_file, 'w') as f:
            f.write('test')
        os.remove(test_file)
        subprocess.run(['pg_dump', '--version'], check=True, capture_output=True)
        return True
    except Exception as e:
        logger.error(f"Prerequisite check failed: {e}")
        return False

def create_compressed_backup():
    """Create a gzip-compressed backup."""
    # Include the year and seconds so filenames sort chronologically
    # (the restore script picks the newest backup by sorting names).
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    sql_file = os.path.join(BACKUP_DIR, f"{DB_NAME}_backup_{timestamp}.sql")
    gz_file = f"{sql_file}.gz"
    try:
        print_step("Running pg_dump...")
        env = os.environ.copy()
        env['PGPASSWORD'] = DB_PASSWORD
        cmd = [
            'pg_dump',
            '-U', DB_USER,
            '-h', 'localhost',
            '-d', DB_NAME,
            '-f', sql_file
        ]
        result = subprocess.run(
            cmd,
            env=env,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True
        )
        if result.returncode != 0:
            raise Exception(f"pg_dump failed: {result.stderr.strip()}")
        if not os.path.exists(sql_file):
            raise Exception("SQL file was not created")
        print_step("Compressing backup file...")
        with open(sql_file, 'rb') as f_in:
            with gzip.open(gz_file, 'wb', compresslevel=COMPRESS_LEVEL) as f_out:
                shutil.copyfileobj(f_in, f_out)
        os.remove(sql_file)
        return gz_file
    except Exception:
        # Best-effort cleanup of partial files before re-raising
        for f in [sql_file, gz_file]:
            if os.path.exists(f):
                try:
                    os.remove(f)
                except OSError:
                    pass
        raise

class ProgressPercentage:
    """Upload progress display."""
    def __init__(self, filename):
        self._filename = filename
        self._size = float(os.path.getsize(filename))
        self._seen_so_far = 0

    def __call__(self, bytes_amount):
        self._seen_so_far += bytes_amount
        percentage = (self._seen_so_far / self._size) * 100
        print(f"\r  Upload progress: {percentage:.2f}% ({self._seen_so_far/1024/1024:.2f}MB)", end='')

def upload_to_s3(file_path):
    """Upload a file to S3."""
    try:
        s3 = boto3.client(
            's3',
            endpoint_url=S3_ENDPOINT,
            aws_access_key_id=S3_ACCESS_KEY,
            aws_secret_access_key=S3_SECRET_KEY,
            region_name='cn-sy1',  # Set this to your provider's region
            config=boto3.session.Config(
                signature_version='s3v4'
            )
        )
        transfer_config = TransferConfig(
            # 25 MB threshold/chunk size. S3 multipart parts must be at least
            # 5 MB, so these are expressed in MB (the original used 1024*25,
            # i.e. 25 KB, which would produce undersized parts).
            multipart_threshold=25 * 1024 * 1024,
            max_concurrency=10,
            multipart_chunksize=25 * 1024 * 1024,
            use_threads=True
        )
        file_name = os.path.basename(file_path)
        print_step(f"Uploading {file_name}...")
        s3.upload_file(
            file_path,
            S3_BUCKET,
            file_name,
            Config=transfer_config,
            Callback=ProgressPercentage(file_path)
        )
        print()  # newline after the progress output
        return True
    except Exception as e:
        logger.error(f"Upload failed: {e}")
        raise
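
# The header lists "clean up old backups" as step 5, but no such routine
# exists in the original script. A minimal sketch, assuming a simple
# keep-the-newest-N retention policy (the function name and retain_count
# are hypothetical, not part of the original):
def cleanup_old_backups(retain_count=7):
    """Delete all but the newest `retain_count` backups from the bucket (sketch)."""
    s3 = boto3.client(
        's3',
        endpoint_url=S3_ENDPOINT,
        aws_access_key_id=S3_ACCESS_KEY,
        aws_secret_access_key=S3_SECRET_KEY,
        region_name='cn-sy1',
        config=boto3.session.Config(signature_version='s3v4')
    )
    response = s3.list_objects_v2(Bucket=S3_BUCKET)
    keys = sorted(
        (obj['Key'] for obj in response.get('Contents', []) if obj['Key'].endswith('.gz')),
        reverse=True  # newest first, given timestamped filenames
    )
    for key in keys[retain_count:]:
        s3.delete_object(Bucket=S3_BUCKET, Key=key)
        logger.info(f"Deleted old backup: {key}")
# To enable it, call cleanup_old_backups() in main() after a successful upload.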

def main():
    print("\n" + "="*50)
    print("PostgreSQL compressed backup script")
    print("="*50 + "\n")
    try:
        if not check_prerequisites():
            raise Exception("Prerequisite check failed")
        backup_file = create_compressed_backup()
        if upload_to_s3(backup_file):
            os.remove(backup_file)
            print_step("Upload succeeded; local file removed")
    except Exception as e:
        logger.error(f"Backup failed: {e}")
        print_step(f"[ERROR] {e}")
    finally:
        print("\n[Done]")

if __name__ == "__main__":
    main()

pg_restore_s3.py

#!/usr/bin/env python3
# Use this script together with pg_backup_s3.py to restore the database.
# What it does:
# 1. Read configuration
# 2. Check prerequisites
# 3. List backup files in S3
# 4. Download the backup file
# 5. Decompress the backup file
# 6. Restore the database
# 7. Clean up temporary files
# 8. Logging
# 9. Error handling
import os
import boto3
import gzip
import subprocess
import shutil
from datetime import datetime
import logging
from botocore.exceptions import ClientError

# Use the same configuration as the backup script
DB_NAME = 'database_name'
DB_USER = 'database_user'
DB_PASSWORD = 'database_password'
S3_ENDPOINT = 'your_s3_endpoint'
S3_ACCESS_KEY = 'your_s3_access_key'
S3_SECRET_KEY = 'your_s3_secret_key'
S3_BUCKET = 'your_s3_bucket_name'
RESTORE_DIR = '/tmp/pg_restores'  # Directory for downloaded restore files

# Logging setup
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler('/var/log/pg_restore.log')
    ]
)
logger = logging.getLogger('PG_Restore')

def print_step(message):
    print(f"→ {message}")

def get_s3_client():
    """Create an S3 client."""
    return boto3.client(
        's3',
        endpoint_url=S3_ENDPOINT,
        aws_access_key_id=S3_ACCESS_KEY,
        aws_secret_access_key=S3_SECRET_KEY,
        region_name='cn-sy1',  # Set this to your provider's region
        config=boto3.session.Config(signature_version='s3v4')
    )

def list_backup_files():
    """List backup files stored in S3."""
    try:
        s3 = get_s3_client()
        # Note: list_objects_v2 returns at most 1000 keys per call; use a
        # paginator if the bucket holds more objects than that.
        response = s3.list_objects_v2(Bucket=S3_BUCKET)
        if 'Contents' not in response:
            print_step("No backup files found in the S3 bucket")
            return []
        files = [obj['Key'] for obj in response['Contents'] if obj['Key'].endswith('.gz')]
        files.sort(reverse=True)  # Newest first, given timestamped filenames
        if not files:
            print_step("No .gz backup files found")
            return []
        return files
    except Exception as e:
        logger.error(f"Failed to list backups: {e}")
        raise

class DownloadProgressPercentage:
    """Download progress display."""
    def __init__(self, filename, total_size):
        self._filename = filename
        self._size = total_size
        self._seen_so_far = 0

    def __call__(self, bytes_amount):
        self._seen_so_far += bytes_amount
        percentage = (self._seen_so_far / self._size) * 100
        print(f"\r  Download progress: {percentage:.2f}% ({self._seen_so_far/1024/1024:.2f}MB)", end='')

def download_from_s3(file_name):
    """Download a backup file from S3."""
    try:
        os.makedirs(RESTORE_DIR, exist_ok=True)
        local_path = os.path.join(RESTORE_DIR, file_name)
        s3 = get_s3_client()
        print_step(f"Downloading {file_name}...")
        # Fetch the object size so the progress callback can show percentages
        file_size = s3.head_object(Bucket=S3_BUCKET, Key=file_name)['ContentLength']
        s3.download_file(
            Bucket=S3_BUCKET,
            Key=file_name,
            Filename=local_path,
            Callback=DownloadProgressPercentage(file_name, file_size)
        )
        print()  # newline after the progress output
        return local_path
    except Exception as e:
        logger.error(f"Failed to download backup file: {e}")
        raise

def decompress_file(compressed_path):
    """Decompress a backup file."""
    try:
        print_step("Decompressing backup file...")
        decompressed_path = compressed_path[:-3]  # strip the .gz suffix
        with gzip.open(compressed_path, 'rb') as f_in:
            with open(decompressed_path, 'wb') as f_out:
                shutil.copyfileobj(f_in, f_out)
        return decompressed_path
    except Exception as e:
        logger.error(f"Failed to decompress backup file: {e}")
        raise

def restore_database(sql_file):
    """Restore the database."""
    try:
        # Let the user choose a restore mode
        print("\nSelect a restore mode:")
        print("1. Full restore (drop the database, then restore)")
        print("2. Append restore (keep existing data, apply the backup on top)")
        while True:
            try:
                mode = int(input("Enter your choice (1 or 2): "))
                if mode in [1, 2]:
                    break
                print("Invalid input; please enter 1 or 2")
            except ValueError:
                print("Please enter a valid number")

        env = os.environ.copy()
        env['PGPASSWORD'] = DB_PASSWORD

        if mode == 1:
            # Full restore: build the restored database under a temporary
            # name, then swap it in place of the original.
            print_step("Preparing full restore...")
            temp_db = f"{DB_NAME}_temp"

            # 0. Drop any leftover temporary database
            print_step("Dropping any existing temporary database...")
            drop_temp_cmd = [
                'sudo', '-u', 'postgres',
                'psql',
                '-c', f"DROP DATABASE IF EXISTS {temp_db};"
            ]
            subprocess.run(drop_temp_cmd, check=True)

            # 1. Create the temporary database
            print_step("Creating temporary database...")
            create_temp_cmd = [
                'sudo', '-u', 'postgres',
                'psql',
                '-c', f"CREATE DATABASE {temp_db} WITH OWNER {DB_USER} ENCODING 'UTF8';"
            ]
            subprocess.run(create_temp_cmd, check=True)

            # 2. Restore the backup into the temporary database
            print_step("Restoring data into the temporary database...")
            restore_temp_cmd = [
                'psql',
                '-U', DB_USER,
                '-h', 'localhost',
                '-d', temp_db,
                '-f', sql_file
            ]
            subprocess.run(restore_temp_cmd, env=env, check=True)

            # 3. Terminate all sessions connected to the original database
            print_step("Terminating connections to the original database...")
            terminate_cmd = [
                'sudo', '-u', 'postgres',
                'psql',
                '-c', f"SELECT pg_terminate_backend(pg_stat_activity.pid) FROM pg_stat_activity WHERE pg_stat_activity.datname = '{DB_NAME}';"
            ]
            subprocess.run(terminate_cmd, check=True)

            # 4. Drop the original database
            print_step("Dropping the original database...")
            drop_orig_cmd = [
                'sudo', '-u', 'postgres',
                'psql',
                '-c', f"DROP DATABASE IF EXISTS {DB_NAME};"
            ]
            subprocess.run(drop_orig_cmd, check=True)

            # 5. Rename the temporary database into place
            print_step("Finalizing restore...")
            rename_cmd = [
                'sudo', '-u', 'postgres',
                'psql',
                '-c', f"ALTER DATABASE {temp_db} RENAME TO {DB_NAME};"
            ]
            subprocess.run(rename_cmd, check=True)
        else:
            # Append restore: apply the backup directly to the existing
            # database. This must only run in mode 2; running it after a
            # full restore would apply the dump a second time.
            print_step("Restoring database...")
            restore_cmd = [
                'psql',
                '-U', DB_USER,
                '-h', 'localhost',
                '-d', DB_NAME,
                '-f', sql_file
            ]
            result = subprocess.run(
                restore_cmd,
                env=env,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True
            )
            if result.returncode != 0:
                raise Exception(f"Restore failed: {result.stderr.strip()}")

        print_step("Database restored successfully")
    except Exception as e:
        logger.error(f"Database restore failed: {e}")
        raise

def cleanup(file_path):
    """Remove a temporary file."""
    try:
        if os.path.exists(file_path):
            os.remove(file_path)
    except Exception as e:
        logger.warning(f"Failed to remove file: {e}")

def main():
    print("\n" + "="*50)
    print("PostgreSQL restore script")
    print("="*50 + "\n")
    try:
        # List available backups
        backup_files = list_backup_files()
        if not backup_files:
            return

        # Show the list of backup files
        print("\nAvailable backup files:")
        for i, file in enumerate(backup_files, 1):
            print(f"{i}. {file}")

        # Choose the backup to restore
        while True:
            try:
                choice = int(input("\nEnter the number of the backup to restore: "))
                if 1 <= choice <= len(backup_files):
                    selected_file = backup_files[choice-1]
                    break
                print("Invalid input; please try again")
            except ValueError:
                print("Please enter a valid number")

        # Download and restore
        compressed_path = download_from_s3(selected_file)
        sql_path = decompress_file(compressed_path)
        restore_database(sql_path)
    except Exception as e:
        print_step(f"[ERROR] {e}")
    finally:
        # Clean up temporary files
        if 'compressed_path' in locals():
            cleanup(compressed_path)
        if 'sql_path' in locals():
            cleanup(sql_path)
        print("\n[Done]")

if __name__ == "__main__":
    main()