Use automap in place of an explicit file map in migration

This commit is contained in:
Emi Simpson 2022-11-28 16:14:31 -05:00
parent f507de8256
commit d14713d077
No known key found for this signature in database
GPG Key ID: 45E9C6E81BD86E7C
1 changed files with 16 additions and 13 deletions

View File

@ -15,6 +15,8 @@ from flask import current_app
from flask_sqlalchemy import SQLAlchemy from flask_sqlalchemy import SQLAlchemy
from pathlib import Path from pathlib import Path
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy.ext.automap import automap_base
from sqlalchemy.orm import Session
import os import os
import time import time
@ -37,24 +39,23 @@ def get_max_lifespan(filesize: int) -> int:
db = SQLAlchemy(current_app.__weakref__()) db = SQLAlchemy(current_app.__weakref__())
# Representations of the original and updated File tables # Representation of the updated (future) File table
class File(db.Model):
id = db.Column(db.Integer, primary_key = True)
sha256 = db.Column(db.String, unique = True)
ext = db.Column(db.UnicodeText)
mime = db.Column(db.UnicodeText)
addr = db.Column(db.UnicodeText)
removed = db.Column(db.Boolean, default=False)
nsfw_score = db.Column(db.Float)
UpdatedFile = sa.table('file', UpdatedFile = sa.table('file',
# We only need to describe the columns that are relevent to us # We only need to describe the columns that are relevent to us
sa.column('id', db.Integer), sa.column('id', db.Integer),
sa.column('expiration', db.BigInteger) sa.column('expiration', db.BigInteger)
) )
Base = automap_base()
def upgrade(): def upgrade():
op.add_column('file', sa.Column('expiration', sa.BigInteger())) op.add_column('file', sa.Column('expiration', sa.BigInteger()))
bind = op.get_bind()
Base.prepare(autoload_with=bind)
File = Base.classes.file
session = Session(bind=bind)
storage = Path(current_app.config["FHOST_STORAGE_PATH"]) storage = Path(current_app.config["FHOST_STORAGE_PATH"])
current_time = time.time() * 1000; current_time = time.time() * 1000;
@ -63,10 +64,12 @@ def upgrade():
unexpired_files = set(os.listdir(storage)) unexpired_files = set(os.listdir(storage))
# Calculate an expiration date for all existing files # Calculate an expiration date for all existing files
files = File.query\ files = session.scalars(
.where( sa.select(File)
sa.not_(File.removed) .where(
).all() sa.not_(File.removed)
)
)
for file in files: for file in files:
if file.sha256 in unexpired_files: if file.sha256 in unexpired_files:
file_path = storage / file.sha256 file_path = storage / file.sha256