From 0426a966dab4ed66f557f9b6ca41726df69ae086 Mon Sep 17 00:00:00 2001 From: Marc Bonnici Date: Tue, 27 Nov 2018 17:34:52 +0000 Subject: [PATCH] utils/postgres: Relocate functions to retrieve schema information Move the functions to retrieve schema information to general utilities to be used in other classes. --- wa/commands/create.py | 30 ++++------------------ wa/output_processors/postgresql.py | 15 +++++------ wa/utils/postgres.py | 41 ++++++++++++++++++++++++++++++ 3 files changed, 52 insertions(+), 34 deletions(-) diff --git a/wa/commands/create.py b/wa/commands/create.py index 603feeae..7db6dd52 100644 --- a/wa/commands/create.py +++ b/wa/commands/create.py @@ -40,12 +40,11 @@ from wa.framework.exception import ConfigError, CommandError from wa.instruments.energy_measurement import EnergyInstrumentBackend from wa.utils.misc import (ensure_directory_exists as _d, capitalize, ensure_file_directory_exists as _f) -from wa.utils.postgres import get_schema +from wa.utils.postgres import get_schema, POSTGRES_SCHEMA_DIR from wa.utils.serializer import yaml TEMPLATES_DIR = os.path.join(os.path.dirname(__file__), 'templates') -POSTGRES_SCHEMA_DIR = os.path.join(os.path.dirname(__file__), 'postgres_schemas') class CreateDatabaseSubcommand(SubCommand): @@ -114,7 +113,7 @@ class CreateDatabaseSubcommand(SubCommand): raise ValueError('Databasename to create cannot be postgres.') self._parse_args(args) - self.schema_major, self.schema_minor, self.sql_commands = _get_schema(self.schemafilepath) + self.schema_major, self.schema_minor, self.sql_commands = get_schema(self.schemafilepath) # Display the version if needed and exit if args.schema_version: @@ -191,7 +190,7 @@ class CreateDatabaseSubcommand(SubCommand): def update_schema(self): self._validate_version() - schema_major, schema_minor, _ = _get_schema(self.schemafilepath) + schema_major, schema_minor, _ = get_schema(self.schemafilepath) meta_oid, current_major, current_minor = self._get_database_schema_version() while not (schema_major == current_major and schema_minor == current_minor): @@ -209,7 +208,7 @@ class CreateDatabaseSubcommand(SubCommand): if not os.path.exists(schema_update): break - _, _, sql_commands = _get_schema(schema_update) + _, _, sql_commands = get_schema(schema_update) self._apply_database_schema(sql_commands, major, minor, meta_oid) msg = "Updated the database schema to v{}.{}" self.logger.debug(msg.format(major, minor)) @@ -226,7 +225,7 @@ class CreateDatabaseSubcommand(SubCommand): # Reset minor to 0 with major version bump current_minor = 0 - _, _, sql_commands = _get_schema(schema_update) + _, _, sql_commands = get_schema(schema_update) self._apply_database_schema(sql_commands, current_major, current_minor, meta_oid) msg = "Updated the database schema to v{}.{}" self.logger.debug(msg.format(current_major, current_minor)) @@ -567,22 +566,3 @@ def get_class_name(name, postfix=''): def touch(path): with open(path, 'w') as _: # NOQA pass - - -def _get_schema(schemafilepath): - sqlfile_path = os.path.join( - POSTGRES_SCHEMA_DIR, schemafilepath) - - with open(sqlfile_path, 'r') as sqlfile: - sql_commands = sqlfile.read() - - schema_major = None - schema_minor = None - # Extract schema version if present - if sql_commands.startswith('--!VERSION'): - splitcommands = sql_commands.split('!ENDVERSION!\n') - schema_major, schema_minor = splitcommands[0].strip('--!VERSION!').split('.') - schema_major = int(schema_major) - schema_minor = int(schema_minor) - sql_commands = splitcommands[1] - return schema_major, schema_minor, sql_commands diff --git a/wa/output_processors/postgresql.py b/wa/output_processors/postgresql.py index 5c7059de..38dd6af1 100644 --- a/wa/output_processors/postgresql.py +++ b/wa/output_processors/postgresql.py @@ -31,7 +31,7 @@ from wa.framework.target.info import CpuInfo from wa.utils.postgres import (POSTGRES_SCHEMA_DIR, cast_level, cast_vanilla, adapt_vanilla, return_as_is, adapt_level, ListOfLevel, adapt_ListOfX, create_iterable_adapter, - get_schema, get_database_schema_version) + get_schema_versions) from wa.utils.serializer import json from wa.utils.types import level @@ -127,7 +127,7 @@ class PostgresqlResultProcessor(OutputProcessor): # N.B. Typecasters are for postgres->python and adapters the opposite self.connect_to_database() self.cursor = self.conn.cursor() - self.check_schema_versions() + self.verify_schema_versions() # Register the adapters and typecasters for enum types self.cursor.execute("SELECT NULL::status_enum") @@ -520,11 +520,9 @@ class PostgresqlResultProcessor(OutputProcessor): self.conn.commit() self.conn.reset() - def check_schema_versions(self): - schemafilepath = os.path.join(POSTGRES_SCHEMA_DIR, 'postgres_schema.sql') - cur_major_version, cur_minor_version, _ = get_schema(schemafilepath) - db_schema_version = get_database_schema_version(self.cursor) - if (cur_major_version, cur_minor_version) != db_schema_version: + def verify_schema_versions(self): + local_schema_version, db_schema_version = get_schema_versions(self.cursor) + if local_schema_version != db_schema_version: self.cursor.close() self.cursor = None self.conn.commit() @@ -532,8 +530,7 @@ class PostgresqlResultProcessor(OutputProcessor): msg = 'The current database schema is v{} however the local ' \ 'schema version is v{}. Please update your database ' \ 'with the create command' - raise OutputProcessorError(msg.format(db_schema_version, - (cur_major_version, cur_minor_version))) + raise OutputProcessorError(msg.format(db_schema_version, local_schema_version)) def _sql_write_lobject(self, source, lobject): with open(source) as lobj_file: diff --git a/wa/utils/postgres.py b/wa/utils/postgres.py index 3a983204..1bedbbc6 100644 --- a/wa/utils/postgres.py +++ b/wa/utils/postgres.py @@ -28,6 +28,7 @@ http://initd.org/psycopg/docs/extensions.html#sql-adaptation-protocol-objects """ import re +import os try: from psycopg2 import InterfaceError @@ -39,6 +40,12 @@ except ImportError: from wa.utils.types import level +POSTGRES_SCHEMA_DIR = os.path.join(os.path.dirname(__file__), + '..', + 'commands', + 'postgres_schemas') + + def cast_level(value, cur): # pylint: disable=unused-argument """Generic Level caster for psycopg2""" if not InterfaceError: @@ -217,3 +224,37 @@ def adapt_list(param): final_string = final_string + str(item) + "," final_string = "{" + final_string + "}" return AsIs("'{}'".format(final_string)) + + +def get_schema(schemafilepath): + with open(schemafilepath, 'r') as sqlfile: + sql_commands = sqlfile.read() + + schema_major = None + schema_minor = None + # Extract schema version if present + if sql_commands.startswith('--!VERSION'): + splitcommands = sql_commands.split('!ENDVERSION!\n') + schema_major, schema_minor = splitcommands[0].strip('--!VERSION!').split('.') + schema_major = int(schema_major) + schema_minor = int(schema_minor) + sql_commands = splitcommands[1] + return schema_major, schema_minor, sql_commands + + +def get_database_schema_version(conn): + with conn.cursor() as cursor: + cursor.execute('''SELECT + DatabaseMeta.schema_major, + DatabaseMeta.schema_minor + FROM + DatabaseMeta;''') + schema_major, schema_minor = cursor.fetchone() + return (schema_major, schema_minor) + + +def get_schema_versions(conn): + schemafilepath = os.path.join(POSTGRES_SCHEMA_DIR, 'postgres_schema.sql') + cur_major_version, cur_minor_version, _ = get_schema(schemafilepath) + db_schema_version = get_database_schema_version(conn) + return (cur_major_version, cur_minor_version), db_schema_version