"""
Sharing snapshots between AWS accounts involves:

1. Creating a key to share between those two accounts, and sharing it
2. (Re)encrypting a snapshot from the live account with the shared key
3. Creating a key on the destination account
4. Copying and re-encrypting the copy to the destination account with that key

The keys from steps 1 and 3 are supplied to TransferSnapshot's constructor (source_kms_key and
destination_kms_key). Its entry-point method, TransferSnapshot.transfer_snapshot(), performs
steps 2 and 4, sharing the recrypted snapshot with the destination account in between, and then
rotates old snapshots out of the destination account.
"""

#!/usr/bin/env python3

import logging
import time
from datetime import datetime, timezone
from operator import itemgetter

from akinaka.client.aws_client import AWS_Client
from akinaka.libs import helpers, exceptions

# Import-time side effects: logging is set up through the project helper first
# (presumably configuring the root logger used by the logging.info calls below - TODO confirm),
# then a single module-level client factory is built.
helpers.set_logger()
# Every RDS client in TransferSnapshot is created through this object via
# aws_client.create_client('rds', region, <role ARN>, valid_for=14400).
aws_client = AWS_Client()

class TransferSnapshot():
    """
    Transfer RDS snapshots from a source AWS account to a destination account, recrypting them
    with the given KMS keys along the way, and rotate old copies out of the destination account.

    Both DB instance snapshots (DBSnapshot records) and DB cluster snapshots (DBClusterSnapshot
    records) are supported; methods tell them apart by the identifier key present in the record.
    """

    def __init__(
            self,
            region,
            source_role_arn,
            destination_role_arn,
            source_kms_key,
            destination_kms_key
        ):
        """
        Args:
            region: AWS region every RDS client is created in.
            source_role_arn: role ARN handed to aws_client.create_client for the source (live)
                account.
            destination_role_arn: role ARN handed to aws_client.create_client for the
                destination account.
            source_kms_key: KMS key description used to recrypt snapshots in the source account
                before sharing them; must provide ['KeyMetadata']['Arn'] and
                ['KeyMetadata']['AWSAccountId'] (presumably a KMS DescribeKey response - TODO
                confirm with callers).
            destination_kms_key: KMS key description used for the final copy in the destination
                account; must provide ['KeyMetadata']['Arn'].
        """

        self.region = region
        self.source_role_arn = source_role_arn
        self.destination_role_arn = destination_role_arn
        self.source_kms_key = source_kms_key
        self.destination_kms_key = destination_kms_key

    def transfer_snapshot(self, take_snapshot, db_names, source_account, destination_account, keep, retention):
        """
        For every DB in [db_names], call methods to perform the actions listed in this module's
        docstring:

        1. take a fresh snapshot (if [take_snapshot]) or use the latest existing one,
        2. recrypt it in the source account with the source KMS key,
        3. share the recrypted snapshot with [destination_account],
        4. copy it into the destination account, recrypted with the destination KMS key,
        5. rotate the destination account's snapshots for that DB (see rotate_snapshots),
           never deleting snapshots whose identifiers are in [keep].
        """

        for db_name in db_names:
            # A fresh client per DB, presumably so each DB's (possibly long) transfer starts with
            # new temporary credentials (valid_for=14400) - TODO confirm
            source_rds_client = aws_client.create_client('rds', self.region, self.source_role_arn, valid_for=14400)

            if take_snapshot:
                source_snapshot = self.take_snapshot(source_rds_client, db_name, self.source_kms_key)
                logging.info("Will now recrypt it with the shared key")
            else:
                source_snapshot = self.get_latest_snapshot(db_name)

            recrypted_snapshot = self.recrypt_snapshot(source_rds_client, source_snapshot, self.source_kms_key, source_account)

            self.share_snapshot(recrypted_snapshot, destination_account)

            logging.info('The snapshot must now be recrypted and copied with a key available only to the destination account')
            destination_rds_client = aws_client.create_client('rds', self.region, self.destination_role_arn, valid_for=14400)
            self.recrypt_snapshot(destination_rds_client, recrypted_snapshot, self.destination_kms_key, destination_account)

            # Pass the caller's keep list through so those snapshots are protected from rotation
            self.rotate_snapshots(retention, db_name, keep=keep)

    def rotate_snapshots(self, retention, db_name, keep):
        """
        Get all the snapshots for [db_name] in the destination account and, if there are more
        than [retention] of them, delete the oldest one whose identifier is not in [keep].
        At most one snapshot is deleted per call.

        DB instance snapshots are looked up first; if there are none, [db_name] is treated as a
        DB cluster and its cluster snapshots are rotated instead.

        Beware, this does not take distinct days into account, only the number of snapshots. So if you
        take more than [retention] snapshots in one day, all previous snapshots will be deleted
        """

        keep = keep or []

        destination_rds_client = aws_client.create_client('rds', self.region, self.destination_role_arn, valid_for=14400)

        snapshots = self._describe_all(
            destination_rds_client, 'describe_db_snapshots', 'DBSnapshots', DBInstanceIdentifier=db_name
        )
        is_cluster = not snapshots
        if is_cluster:
            snapshots = self._describe_all(
                destination_rds_client, 'describe_db_cluster_snapshots', 'DBClusterSnapshots',
                DBClusterIdentifier=db_name
            )

        if len(snapshots) <= retention:
            return

        id_key = 'DBClusterSnapshotIdentifier' if is_cluster else 'DBSnapshotIdentifier'
        oldest_snapshot = self._oldest_deletable(snapshots, keep, id_key)

        if oldest_snapshot is None:
            logging.info("There are more than the given retention number of snapshots for {}, " \
                         "but they are all in the --keep list so none will be deleted".format(db_name))
            return

        snapshot_id = oldest_snapshot[id_key]
        logging.info("There are more than the given retention number of snapshots in the account, " \
                     "so we're going to delete the oldest one not in the --keep list: {}".format(snapshot_id))

        if is_cluster:
            destination_rds_client.delete_db_cluster_snapshot(DBClusterSnapshotIdentifier=snapshot_id)
        else:
            destination_rds_client.delete_db_snapshot(DBSnapshotIdentifier=snapshot_id)

    def get_latest_snapshot(self, db_name):
        """
        Return the most recently created snapshot of [db_name] in the source account.

        DB instance snapshots are looked up first; if there are none, [db_name] is treated as a
        DB cluster and its cluster snapshots are used instead.

        Note: the original note here says [db_name] may also be the DB's ARN, and that only the
        plain DB name works from inside the DB's account - TODO confirm which form callers pass.

        Raises:
            exceptions.AkinakaCriticalException: if no instance or cluster snapshots exist.
        """

        source_rds_client = aws_client.create_client('rds', self.region, self.source_role_arn, valid_for=14400)

        snapshots = self._describe_all(
            source_rds_client, 'describe_db_snapshots', 'DBSnapshots', DBInstanceIdentifier=db_name
        )
        id_key = 'DBSnapshotIdentifier'

        if not snapshots:
            snapshots = self._describe_all(
                source_rds_client, 'describe_db_cluster_snapshots', 'DBClusterSnapshots',
                DBClusterIdentifier=db_name
            )
            id_key = 'DBClusterSnapshotIdentifier'

        if not snapshots:
            raise exceptions.AkinakaCriticalException("No snapshots found for {}. You'll need to take one first with --take-snapshot".format(db_name))

        # Latest = greatest creation time
        latest = max(snapshots, key=itemgetter('SnapshotCreateTime'))
        logging.info("Using snapshot {}".format(latest[id_key]))

        return latest

    def make_snapshot_name(self, db_name, account):
        """
        Make a snapshot name "<db_name>-<YYYYmmdd-HHMM>-<account>" from [db_name], the current
        UTC time and [account].

        NOTE(review): the timestamp only has minute resolution, so two calls with the same
        [db_name] and [account] within the same minute produce the same name.
        """

        date = datetime.now(timezone.utc).strftime('%Y%m%d-%H%M')

        return "{}-{}-{}".format(db_name, date, account)

    def recrypt_snapshot(self, rds_client, snapshot, kms_key, destination_account, tags=None):
        """
        Copy [snapshot] to a new snapshot encrypted with the KMS key [kms_key], wait for the copy
        to become available, and return the new snapshot record.

        The copy is named by make_snapshot_name from the snapshot's DB (or cluster) identifier and
        [destination_account]. If a snapshot with that name already exists, it is waited on and
        returned instead of copying again.

        Args:
            rds_client: RDS client for the account the copy is made in.
            snapshot: DB instance or DB cluster snapshot record to copy; its ARN is used as the
                copy source.
            kms_key: dict providing ['KeyMetadata']['Arn'], the key the copy is encrypted with.
            destination_account: account ID used in the copy's name.
            tags: optional extra tags appended after the akinaka-made tag; presumably boto3-style
                [{'Key': ..., 'Value': ...}] dicts - TODO confirm with callers.
        """

        kms_key_arn = kms_key['KeyMetadata']['Arn']
        copy_tags = self._snapshot_tags(tags)

        if self._is_cluster_snapshot(snapshot):
            new_snapshot_id = self.make_snapshot_name(snapshot['DBClusterIdentifier'], destination_account)
            try:
                response = rds_client.copy_db_cluster_snapshot(
                    SourceDBClusterSnapshotIdentifier=snapshot['DBClusterSnapshotArn'],
                    TargetDBClusterSnapshotIdentifier=new_snapshot_id,
                    KmsKeyId=kms_key_arn,
                    Tags=copy_tags
                )
            except rds_client.exceptions.DBClusterSnapshotAlreadyExistsFault:
                existing = rds_client.describe_db_cluster_snapshots(
                    DBClusterSnapshotIdentifier=new_snapshot_id
                )['DBClusterSnapshots'][0]
                logging.info("Found existing snapshot {}".format(existing['DBClusterSnapshotIdentifier']))
                # It may still be copying; later steps (sharing, copying) need it available
                self.wait_for_snapshot(existing, rds_client)
                return existing

            recrypted = response['DBClusterSnapshot']
            recrypted_id = recrypted['DBClusterSnapshotIdentifier']
        else:
            new_snapshot_id = self.make_snapshot_name(snapshot['DBInstanceIdentifier'], destination_account)
            try:
                response = rds_client.copy_db_snapshot(
                    SourceDBSnapshotIdentifier=snapshot['DBSnapshotArn'],
                    TargetDBSnapshotIdentifier=new_snapshot_id,
                    KmsKeyId=kms_key_arn,
                    Tags=copy_tags
                )
            except rds_client.exceptions.DBSnapshotAlreadyExistsFault:
                existing = rds_client.describe_db_snapshots(
                    DBSnapshotIdentifier=new_snapshot_id
                )['DBSnapshots'][0]
                logging.info("Found existing snapshot {}".format(existing['DBSnapshotIdentifier']))
                # It may still be copying; later steps (sharing, copying) need it available
                self.wait_for_snapshot(existing, rds_client)
                return existing

            recrypted = response['DBSnapshot']
            recrypted_id = recrypted['DBSnapshotIdentifier']

        self.wait_for_snapshot(recrypted, rds_client)

        logging.info("Recrypted snapshot {} with key {}".format(recrypted_id, kms_key_arn))

        return recrypted

    def share_snapshot(self, snapshot, destination_account):
        """
        Share [snapshot] (in the source account) with [destination_account] by adding the
        account to the snapshot's 'restore' attribute.
        """

        source_rds_client = aws_client.create_client('rds', self.region, self.source_role_arn, valid_for=14400)

        if self._is_cluster_snapshot(snapshot):
            snapshot_id = snapshot['DBClusterSnapshotIdentifier']
            source_rds_client.modify_db_cluster_snapshot_attribute(
                DBClusterSnapshotIdentifier=snapshot_id,
                AttributeName='restore',
                ValuesToAdd=[destination_account]
            )
        else:
            snapshot_id = snapshot['DBSnapshotIdentifier']
            source_rds_client.modify_db_snapshot_attribute(
                DBSnapshotIdentifier=snapshot_id,
                AttributeName='restore',
                ValuesToAdd=[destination_account]
            )

        logging.info("Recrypted snapshot {} has been shared with account {}".format(snapshot_id, destination_account))

    def wait_for_snapshot(self, snapshot, rds_client, poll_interval=10):
        """
        Block until [snapshot] has status 'available', re-querying it through [rds_client]
        every [poll_interval] seconds.

        Raises:
            exceptions.AkinakaCriticalException: if the snapshot reaches status 'failed', which
                would otherwise leave this loop polling forever.
        """

        if self._is_cluster_snapshot(snapshot):
            snapshot_id = snapshot['DBClusterSnapshotIdentifier']

            def describe():
                return rds_client.describe_db_cluster_snapshots(
                    DBClusterSnapshotIdentifier=snapshot_id
                )['DBClusterSnapshots'][0]
        else:
            snapshot_id = snapshot['DBSnapshotIdentifier']

            def describe():
                return rds_client.describe_db_snapshots(
                    DBSnapshotIdentifier=snapshot_id
                )['DBSnapshots'][0]

        while True:
            current = describe()
            status = current['Status']

            if status == 'available':
                logging.info("Snapshot {} has been created".format(snapshot_id))
                return

            if status == 'failed':
                raise exceptions.AkinakaCriticalException(
                    "Snapshot {} failed and will never become available".format(snapshot_id)
                )

            logging.info("Snapshot {} is in progress; {}% complete".format(snapshot_id, current.get('PercentProgress')))
            time.sleep(poll_interval)

    def take_snapshot(self, rds_client, db_name, kms_key):
        """
        Take a new snapshot of the DB instance [db_name], or of the DB cluster [db_name] if no
        such instance exists, wait for it to become available, and return its record.

        [kms_key] is only used to name the snapshot (via ['KeyMetadata']['AWSAccountId']):
        create_db_snapshot / create_db_cluster_snapshot are called without a KMS key, so the
        snapshot's encryption is whatever RDS applies for that DB, which is why the snapshot is
        recrypted afterwards.

        TODO: It's not possible to take a snapshot with a CMK, really?!
        """

        snapshot_name = self.make_snapshot_name(db_name, kms_key['KeyMetadata']['AWSAccountId'])

        try:
            response = rds_client.create_db_snapshot(
                DBInstanceIdentifier=db_name,
                DBSnapshotIdentifier=snapshot_name,
                Tags=self._snapshot_tags()
            )
        except rds_client.exceptions.DBInstanceNotFoundFault:
            # No DB instance by that name; assume [db_name] is a DB cluster
            response = rds_client.create_db_cluster_snapshot(
                DBClusterIdentifier=db_name,
                DBClusterSnapshotIdentifier=snapshot_name,
                Tags=self._snapshot_tags()
            )
            snapshot = response['DBClusterSnapshot']
            snapshot_id = snapshot['DBClusterSnapshotIdentifier']
        else:
            snapshot = response['DBSnapshot']
            snapshot_id = snapshot['DBSnapshotIdentifier']

        self.wait_for_snapshot(snapshot, rds_client)

        logging.info("Snapshot {} created".format(snapshot_id))

        return snapshot

    # ----- private helpers -----

    @staticmethod
    def _is_cluster_snapshot(snapshot):
        """Return True if [snapshot] is a DB cluster snapshot record rather than a DB instance one."""
        return 'DBClusterSnapshotIdentifier' in snapshot

    @staticmethod
    def _snapshot_tags(tags=None):
        """Return the akinaka-made tag followed by any extra [tags]."""
        return [{'Key': 'akinaka-made', 'Value': 'true'}] + list(tags or [])

    @staticmethod
    def _describe_all(rds_client, operation, result_key, **kwargs):
        """Run the paginated RDS describe [operation] with [kwargs] and return every [result_key] record from all pages."""
        paginator = rds_client.get_paginator(operation)
        return [
            record
            for page in paginator.paginate(**kwargs)
            for record in page.get(result_key, [])
        ]

    @staticmethod
    def _oldest_deletable(snapshots, keep, id_key):
        """Return the oldest of [snapshots] whose [id_key] is not in [keep], or None if every one is kept."""
        candidates = [snapshot for snapshot in snapshots if snapshot[id_key] not in keep]
        return min(candidates, key=itemgetter('SnapshotCreateTime'), default=None)