Aurora support

parent 1876417f
......@@ -73,9 +73,9 @@ def create_kms_key(region, assumable_role_arn):
@dr.command()
@click.pass_context
@click.option("--take-snapshot", is_flag=True, help="TODO: Boolean, default false. Take a live snapshot now, or take the existing latest snapshot")
@click.option("--db-arns", required=False, help="Comma separated list of either DB names or ARNs to transfer")
@click.option("--service", required=False, help="The service to transfer backups for. Defaults to all (RDS, S3)")
def transfer(ctx, take_snapshot, db_arns, service):
@click.option("--db-names", required=False, help="Comma separated list of DB names to transfer")
@click.option("--service", type=click.Choice(['rds', 'aurora', 's3']), required=False, help="The service to transfer backups for. Defaults to all (RDS, S3)")
def transfer(ctx, take_snapshot, db_names, service):
"""
Backup [service] from owning account of [ctx.source_role_arn] to owning account
of [ctx.destination_role_arn].
......@@ -95,18 +95,45 @@ def transfer(ctx, take_snapshot, db_arns, service):
destination_kms_key = create_kms_key(region, destination_role_arn)
if service == 'rds':
if db_names:
db_names = [db_names.replace(' ','')]
else:
scanner = scan_resources_storage.ScanResources(region, source_role_arn)
db_names = db_names or scanner.scan_rds_instances()['db_names']
rds(
dry_run,
region,
source_role_arn,
destination_role_arn,
take_snapshot,
db_names,
source_kms_key,
destination_kms_key,
source_account,
destination_account)
if service == 'aurora':
if db_names:
db_names = [db_names.replace(' ','')]
else:
scanner = scan_resources_storage.ScanResources(region, source_role_arn)
db_names = db_names or scanner.scan_rds_aurora()['aurora_names']
rds(
dry_run,
region,
source_role_arn,
destination_role_arn,
take_snapshot,
db_arns,
db_names,
source_kms_key,
destination_kms_key,
source_account,
destination_account)
if service == 's3':
logging.info('TODO')
def rds(
dry_run,
......@@ -114,7 +141,7 @@ def rds(
source_role_arn,
destination_role_arn,
take_snapshot,
db_arns,
db_names,
source_kms_key,
destination_kms_key,
source_account,
......@@ -123,14 +150,8 @@ def rds(
Call the RDS class to transfer snapshots
"""
if db_arns:
db_arns = [db_arns.replace(' ','')]
else:
scanner = scan_resources_storage.ScanResources(region, source_role_arn)
db_arns = db_arns or scanner.scan_rds()['rds_arns']
logging.info("Will attempt to backup the following RDS instances, unless this is a dry run:")
logging.info(db_arns)
logging.info(db_names)
if dry_run:
exit(0)
......@@ -146,7 +167,7 @@ def rds(
rds.transfer_snapshot(
take_snapshot=take_snapshot,
db_arns=db_arns,
db_names=db_names,
source_account=source_account,
destination_account=destination_account
)
......@@ -17,7 +17,7 @@ method; transfer_snapshot()
from datetime import datetime
from operator import itemgetter
from akinaka.client.aws_client import AWS_Client
from akinaka.libs import helpers
from akinaka.libs import helpers, exceptions
import logging
import time
......@@ -26,12 +26,12 @@ aws_client = AWS_Client()
class TransferSnapshot():
def __init__(
self,
region,
source_role_arn,
destination_role_arn,
source_kms_key,
destination_kms_key
self,
region,
source_role_arn,
destination_role_arn,
source_kms_key,
destination_kms_key
):
self.region = region
......@@ -40,9 +40,9 @@ class TransferSnapshot():
self.source_kms_key = source_kms_key
self.destination_kms_key = destination_kms_key
def transfer_snapshot(self, take_snapshot, db_arns, source_account, destination_account):
def transfer_snapshot(self, take_snapshot, db_names, source_account, destination_account):
"""
For every DB in [db_arns], call methods to:
For every DB in [db_names], call methods to:
1. Either take a new snapshot (TODO), or use the latest automatically created one
2. Recrypt the snapshot with [self.source_kms_key]. This key must be shared between accounts
......@@ -50,38 +50,46 @@ class TransferSnapshot():
4. Copy it to self.destination_account with the [destination_kms_key]
"""
for arn in db_arns:
source_rds_client = aws_client.create_client('rds', self.region, self.source_role_arn)
for db_name in db_names:
source_rds_client = aws_client.create_client('rds', self.region, self.source_role_arn, valid_for=14400)
if take_snapshot:
source_snapshot = self.take_snapshot(source_rds_client, arn, self.source_kms_key)
source_snapshot = self.take_snapshot(source_rds_client, db_name, self.source_kms_key)
logging.info("Will now recrypt it with the shared key")
else:
source_snapshot = self.get_latest_snapshot(arn)
source_snapshot = self.get_latest_snapshot(db_name)
recrypted_snapshot = self.recrypt_snapshot(source_rds_client, source_snapshot, self.source_kms_key, source_account)
self.share_snapshot(recrypted_snapshot, destination_account)
destination_rds_client = aws_client.create_client('rds', self.region, self.destination_role_arn)
logging.info('The snapshot must now be recrypted and copied with a key available only to the destination account')
destination_rds_client = aws_client.create_client('rds', self.region, self.destination_role_arn, valid_for=14400)
self.recrypt_snapshot(destination_rds_client, recrypted_snapshot, self.destination_kms_key, destination_account)
def get_latest_snapshot(self, db_arn):
def get_latest_snapshot(self, db_name):
"""
Return the latest snapshot for [db_arn], where the ARN can also be the name of the DB
Return the latest snapshot for [db_name], where the ARN can also be the name of the DB
Note: You can only use the db_arn if you are in the account with the DB in it, else you
Note: You can only use the db_name if you are in the account with the DB in it, else you
must use the DB name
"""
source_rds_client = aws_client.create_client('rds', self.region, self.source_role_arn)
snapshots = source_rds_client.describe_db_snapshots(DBInstanceIdentifier=db_arn)['DBSnapshots']
source_rds_client = aws_client.create_client('rds', self.region, self.source_role_arn, valid_for=14400)
# TODO: Yuk
# https://stackoverflow.com/questions/59285540/rewrite-python-method-depending-on-condition
try:
snapshots = source_rds_client.describe_db_snapshots(DBInstanceIdentifier=db_name)['DBSnapshots']
if len(snapshots) == 0:
raise exceptions.AkinakaCriticalException("No snapshots found for {}. You'll need to take one first with --take-snapshot".format(db_name))
latest = sorted(snapshots, key=itemgetter('SnapshotCreateTime'))[0]
except KeyError:
logging.error("Couldn't get the latest snapshot, probably because it's still being made")
logging.info("Found automatic snapshot {}".format(latest['DBSnapshotIdentifier']))
logging.info("Using snapshot {}".format(latest['DBSnapshotIdentifier']))
except IndexError:
snapshots = source_rds_client.describe_db_cluster_snapshots(DBClusterIdentifier=db_name)['DBClusterSnapshots']
if len(snapshots) == 0:
raise exceptions.AkinakaCriticalException("No snapshots found for {}. You'll need to take one first with --take-snapshot".format(db_name))
latest = sorted(snapshots, key=itemgetter('SnapshotCreateTime'))[0]
logging.info("Using snapshot {}".format(latest['DBClusterSnapshotIdentifier']))
return latest
......@@ -97,76 +105,146 @@ class TransferSnapshot():
Recrypt a snapshot [snapshot] with the KMS key [kms_key]. Return the recrypted snapshot.
"""
new_snapshot_id = self.make_snapshot_name(snapshot['DBInstanceIdentifier'], destination_account)
# TODO: Yuk
# https://stackoverflow.com/questions/59285540/rewrite-python-method-depending-on-condition
try:
recrypted_snapshot = rds_client.copy_db_snapshot(
SourceDBSnapshotIdentifier=snapshot['DBSnapshotArn'],
TargetDBSnapshotIdentifier=new_snapshot_id,
KmsKeyId=kms_key['KeyMetadata']['Arn'],
Tags=[ { 'Key': 'akinaka-made', 'Value': 'true' }, ] # FIXME: Add custom tags
)
new_snapshot_id = self.make_snapshot_name(snapshot['DBInstanceIdentifier'], destination_account)
try:
recrypted_snapshot = rds_client.copy_db_snapshot(
SourceDBSnapshotIdentifier=snapshot['DBSnapshotArn'],
TargetDBSnapshotIdentifier=new_snapshot_id,
KmsKeyId=kms_key['KeyMetadata']['Arn'],
Tags=[ { 'Key': 'akinaka-made', 'Value': 'true' }, ] # FIXME: Add custom tags
)
self.wait_for_snapshot(recrypted_snapshot['DBSnapshot'], rds_client)
self.wait_for_snapshot(recrypted_snapshot['DBSnapshot'], rds_client)
logging.info("Recrypted snapshot {} with key {}".format(
recrypted_snapshot['DBSnapshot']['DBSnapshotIdentifier'],
kms_key['KeyMetadata']['Arn']
))
logging.info("Recrypted snapshot {} with key {}".format(
recrypted_snapshot['DBSnapshot']['DBSnapshotIdentifier'],
kms_key['KeyMetadata']['Arn']
))
return recrypted_snapshot['DBSnapshot']
except rds_client.exceptions.DBSnapshotAlreadyExistsFault:
snapshots = rds_client.describe_db_snapshots(DBSnapshotIdentifier=new_snapshot_id)
return recrypted_snapshot['DBSnapshot']
except rds_client.exceptions.DBSnapshotAlreadyExistsFault:
snapshots = rds_client.describe_db_snapshots(DBSnapshotIdentifier=new_snapshot_id)
logging.info("Found existing snapshot {}".format(snapshots['DBSnapshots'][0]['DBSnapshotIdentifier']))
return snapshots['DBSnapshots'][0]
except KeyError:
new_snapshot_id = self.make_snapshot_name(snapshot['DBClusterIdentifier'], destination_account)
try:
recrypted_snapshot = rds_client.copy_db_cluster_snapshot(
SourceDBClusterSnapshotIdentifier=snapshot['DBClusterSnapshotArn'],
TargetDBClusterSnapshotIdentifier=new_snapshot_id,
KmsKeyId=kms_key['KeyMetadata']['Arn'],
Tags=[ { 'Key': 'akinaka-made', 'Value': 'true' }, ] # FIXME: Add custom tags
)
self.wait_for_snapshot(recrypted_snapshot['DBClusterSnapshot'], rds_client)
logging.info("Recrypted snapshot {} with key {}".format(
recrypted_snapshot['DBClusterSnapshot']['DBClusterSnapshotIdentifier'],
kms_key['KeyMetadata']['Arn']
))
return recrypted_snapshot['DBClusterSnapshot']
except rds_client.exceptions.DBClusterSnapshotAlreadyExistsFault:
snapshots = rds_client.describe_db_cluster_snapshots(DBClusterSnapshotIdentifier=new_snapshot_id)
logging.info("Found existing snapshot {}".format(snapshots['DBClusterSnapshots'][0]['DBClusterSnapshotIdentifier']))
return snapshots['DBClusterSnapshots'][0]
logging.info("Found existing snapshot {}".format(snapshots['DBSnapshots'][0]['DBSnapshotIdentifier']))
return snapshots['DBSnapshots'][0]
def share_snapshot(self, snapshot, destination_account):
"""
Share [snapshot] with [destination_account]
"""
source_rds_client = aws_client.create_client('rds', self.region, self.source_role_arn)
source_rds_client.modify_db_snapshot_attribute(
DBSnapshotIdentifier=snapshot['DBSnapshotIdentifier'],
AttributeName='restore',
ValuesToAdd=[destination_account]
)
# TODO: Yuk
# https://stackoverflow.com/questions/59285540/rewrite-python-method-depending-on-condition
try:
source_rds_client = aws_client.create_client('rds', self.region, self.source_role_arn, valid_for=14400)
source_rds_client.modify_db_snapshot_attribute(
DBSnapshotIdentifier=snapshot['DBSnapshotIdentifier'],
AttributeName='restore',
ValuesToAdd=[destination_account]
)
logging.info("Recrypted snapshot {} has been shared with account {}".format(snapshot['DBSnapshotIdentifier'], destination_account))
except KeyError:
source_rds_client = aws_client.create_client('rds', self.region, self.source_role_arn, valid_for=14400)
source_rds_client.modify_db_cluster_snapshot_attribute(
DBClusterSnapshotIdentifier=snapshot['DBClusterSnapshotIdentifier'],
AttributeName='restore',
ValuesToAdd=[destination_account]
)
logging.info("Recrypted snapshot {} has been shared with account {}".format(snapshot['DBClusterSnapshotIdentifier'], destination_account))
logging.info("Recrypted snapshot {} has been shared with account {}".format(snapshot['DBSnapshotIdentifier'], destination_account))
def wait_for_snapshot(self, snapshot, rds_client):
"""
Check if [snapshot] is ready by querying it every 10 seconds
"""
while True:
snapshotcheck = rds_client.describe_db_snapshots(
DBSnapshotIdentifier=snapshot['DBSnapshotIdentifier']
)['DBSnapshots'][0]
if snapshotcheck['Status'] == 'available':
logging.info("Snapshot {} complete and available!".format(snapshot['DBSnapshotIdentifier']))
break
else:
logging.info("Snapshot {} is in progress; {}% complete".format(snapshot['DBSnapshotIdentifier'], snapshotcheck['PercentProgress']))
time.sleep(10)
# TODO: Yuk
# https://stackoverflow.com/questions/59285540/rewrite-python-method-depending-on-condition
try:
while True:
snapshotcheck = rds_client.describe_db_snapshots(
DBSnapshotIdentifier=snapshot['DBSnapshotIdentifier']
)['DBSnapshots'][0]
if snapshotcheck['Status'] == 'available':
logging.info("Snapshot {} has been created".format(snapshot['DBSnapshotIdentifier']))
break
else:
logging.info("Snapshot {} is in progress; {}% complete".format(snapshot['DBSnapshotIdentifier'], snapshotcheck['PercentProgress']))
time.sleep(10)
except KeyError:
while True:
snapshotcheck = rds_client.describe_db_cluster_snapshots(
DBClusterSnapshotIdentifier=snapshot['DBClusterSnapshotIdentifier']
)['DBClusterSnapshots'][0]
if snapshotcheck['Status'] == 'available':
logging.info("Snapshot {} has been created".format(snapshot['DBClusterSnapshotIdentifier']))
break
else:
logging.info("Snapshot {} is in progress; {}% complete".format(snapshot['DBClusterSnapshotIdentifier'], snapshotcheck['PercentProgress']))
time.sleep(10)
def take_snapshot(self, rds_client, db_name, kms_key):
"""
TODO: Take a new snapshot of [db_name] using [kms_key]. If we're here, we don't need to
recrypt, since we already have a shared key to begin with. Some of the logic in
transfer_snapshot() will need to be changed to accommodate this once ready
Take a new snapshot of [db_name] using [kms_key]
Untested.
TODO: It's not possible to take a snapshot with a CMK, really?!
"""
snapshot_name = self.make_snapshot_name(db_name, kms_key['KeyMetadata']['AWSAccountId'])
snapshot = rds_client.create_db_snapshot(
DBInstanceIdentifier=db_name,
DBSnapshotIdentifier=snapshot_name,
Tags=[ { 'Key': 'akinaka-made', 'Value': 'true' }, ]
)
logging.info("Snapshot created.")
try:
snapshot = rds_client.create_db_snapshot(
DBInstanceIdentifier=db_name,
DBSnapshotIdentifier=snapshot_name,
Tags=[ { 'Key': 'akinaka-made', 'Value': 'true' }, ]
)
self.wait_for_snapshot(snapshot['DBSnapshot'], rds_client)
logging.info("Snapshot {} created".format(snapshot['DBSnapshot']['DBSnapshotIdentifier']))
return snapshot['DBSnapshot']
except rds_client.exceptions.DBInstanceNotFoundFault:
snapshot = rds_client.create_db_cluster_snapshot(
DBClusterIdentifier=db_name,
DBClusterSnapshotIdentifier=snapshot_name,
Tags=[ { 'Key': 'akinaka-made', 'Value': 'true' }, ]
)
self.wait_for_snapshot(snapshot['DBClusterSnapshot'], rds_client)
logging.info("Snapshot {} created".format(snapshot['DBClusterSnapshot']['DBClusterSnapshotIdentifier']))
return snapshot['DBSnapshot']
return snapshot['DBClusterSnapshot']
......@@ -38,30 +38,30 @@ class ScanResources():
def scan_all(self):
""" Scan all resource types in scope, and return separate lists for each """
rds_arns = self.scan_rds()
aurora_arns = self.scan_aurora()
rds_instance_names = self.scan_rds_instances()
rds_aurora_names = self.scan_rds_aurora()
s3_arns = self.scan_s3()
all_arns = { **rds_arns, **aurora_arns, **s3_arns }
all_arns = { **rds_instance_names, **rds_aurora_names, **s3_arns }
return all_arns
def scan_rds(self):
""" Return list of ARNs for all RDS objects """
def scan_rds_instances(self):
""" Return list of ARNs for all RDS DB instances """
rds_client = aws_client.create_client('rds', self.region, self.role_arn)
response = rds_client.describe_db_instances()['DBInstances']
arns = [db['DBInstanceArn'] for db in response]
names = [db['DBInstanceIdentifier'] for db in response if 'DBClusterIdentifier' not in db.keys()]
return { 'rds_arns': arns }
return { 'db_names': names }
def scan_aurora(self):
def scan_rds_aurora(self):
""" Return list of ARNs for all RDS Aurora objects """
rds_client = aws_client.create_client('rds', self.region, self.role_arn)
response = rds_client.describe_db_clusters()['DBClusters']
arns = [db['DBClusterArn'] for db in response]
names = [db['DBClusterIdentifier'] for db in response]
return { 'aurora_arns': arns }
return { 'aurora_names': names }
def scan_s3(self):
""" Return list of ARNs for all S3 buckets """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment