mirror of
				https://github.com/ARM-software/workload-automation.git
				synced 2025-11-04 09:02:12 +00:00 
			
		
		
		
	utils/postgres: Relocate functions to retrieve schema information
Move the functions to retrieve schema information to general utilities to be used in other classes.
This commit is contained in:
		@@ -40,12 +40,11 @@ from wa.framework.exception import ConfigError, CommandError
 | 
			
		||||
from wa.instruments.energy_measurement import EnergyInstrumentBackend
 | 
			
		||||
from wa.utils.misc import (ensure_directory_exists as _d, capitalize,
 | 
			
		||||
                           ensure_file_directory_exists as _f)
 | 
			
		||||
from wa.utils.postgres import get_schema
 | 
			
		||||
from wa.utils.postgres import get_schema, POSTGRES_SCHEMA_DIR
 | 
			
		||||
from wa.utils.serializer import yaml
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
TEMPLATES_DIR = os.path.join(os.path.dirname(__file__), 'templates')
 | 
			
		||||
POSTGRES_SCHEMA_DIR = os.path.join(os.path.dirname(__file__), 'postgres_schemas')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CreateDatabaseSubcommand(SubCommand):
 | 
			
		||||
@@ -114,7 +113,7 @@ class CreateDatabaseSubcommand(SubCommand):
 | 
			
		||||
            raise ValueError('Databasename to create cannot be postgres.')
 | 
			
		||||
 | 
			
		||||
        self._parse_args(args)
 | 
			
		||||
        self.schema_major, self.schema_minor, self.sql_commands = _get_schema(self.schemafilepath)
 | 
			
		||||
        self.schema_major, self.schema_minor, self.sql_commands = get_schema(self.schemafilepath)
 | 
			
		||||
 | 
			
		||||
        # Display the version if needed and exit
 | 
			
		||||
        if args.schema_version:
 | 
			
		||||
@@ -191,7 +190,7 @@ class CreateDatabaseSubcommand(SubCommand):
 | 
			
		||||
 | 
			
		||||
    def update_schema(self):
 | 
			
		||||
        self._validate_version()
 | 
			
		||||
        schema_major, schema_minor, _ = _get_schema(self.schemafilepath)
 | 
			
		||||
        schema_major, schema_minor, _ = get_schema(self.schemafilepath)
 | 
			
		||||
        meta_oid, current_major, current_minor = self._get_database_schema_version()
 | 
			
		||||
 | 
			
		||||
        while not (schema_major == current_major and schema_minor == current_minor):
 | 
			
		||||
@@ -209,7 +208,7 @@ class CreateDatabaseSubcommand(SubCommand):
 | 
			
		||||
            if not os.path.exists(schema_update):
 | 
			
		||||
                break
 | 
			
		||||
 | 
			
		||||
            _, _, sql_commands = _get_schema(schema_update)
 | 
			
		||||
            _, _, sql_commands = get_schema(schema_update)
 | 
			
		||||
            self._apply_database_schema(sql_commands, major, minor, meta_oid)
 | 
			
		||||
            msg = "Updated the database schema to v{}.{}"
 | 
			
		||||
            self.logger.debug(msg.format(major, minor))
 | 
			
		||||
@@ -226,7 +225,7 @@ class CreateDatabaseSubcommand(SubCommand):
 | 
			
		||||
 | 
			
		||||
        # Reset minor to 0 with major version bump
 | 
			
		||||
        current_minor = 0
 | 
			
		||||
        _, _, sql_commands = _get_schema(schema_update)
 | 
			
		||||
        _, _, sql_commands = get_schema(schema_update)
 | 
			
		||||
        self._apply_database_schema(sql_commands, current_major, current_minor, meta_oid)
 | 
			
		||||
        msg = "Updated the database schema to v{}.{}"
 | 
			
		||||
        self.logger.debug(msg.format(current_major, current_minor))
 | 
			
		||||
@@ -567,22 +566,3 @@ def get_class_name(name, postfix=''):
 | 
			
		||||
def touch(path):
 | 
			
		||||
    with open(path, 'w') as _: # NOQA
 | 
			
		||||
        pass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _get_schema(schemafilepath):
 | 
			
		||||
    sqlfile_path = os.path.join(
 | 
			
		||||
        POSTGRES_SCHEMA_DIR, schemafilepath)
 | 
			
		||||
 | 
			
		||||
    with open(sqlfile_path, 'r') as sqlfile:
 | 
			
		||||
        sql_commands = sqlfile.read()
 | 
			
		||||
 | 
			
		||||
    schema_major = None
 | 
			
		||||
    schema_minor = None
 | 
			
		||||
    # Extract schema version if present
 | 
			
		||||
    if sql_commands.startswith('--!VERSION'):
 | 
			
		||||
        splitcommands = sql_commands.split('!ENDVERSION!\n')
 | 
			
		||||
        schema_major, schema_minor = splitcommands[0].strip('--!VERSION!').split('.')
 | 
			
		||||
        schema_major = int(schema_major)
 | 
			
		||||
        schema_minor = int(schema_minor)
 | 
			
		||||
        sql_commands = splitcommands[1]
 | 
			
		||||
    return schema_major, schema_minor, sql_commands
 | 
			
		||||
 
 | 
			
		||||
@@ -31,7 +31,7 @@ from wa.framework.target.info import CpuInfo
 | 
			
		||||
from wa.utils.postgres import (POSTGRES_SCHEMA_DIR, cast_level, cast_vanilla,
 | 
			
		||||
                               adapt_vanilla, return_as_is, adapt_level,
 | 
			
		||||
                               ListOfLevel, adapt_ListOfX, create_iterable_adapter,
 | 
			
		||||
                               get_schema, get_database_schema_version)
 | 
			
		||||
                               get_schema_versions)
 | 
			
		||||
from wa.utils.serializer import json
 | 
			
		||||
from wa.utils.types import level
 | 
			
		||||
 | 
			
		||||
@@ -127,7 +127,7 @@ class PostgresqlResultProcessor(OutputProcessor):
 | 
			
		||||
        # N.B. Typecasters are for postgres->python and adapters the opposite
 | 
			
		||||
        self.connect_to_database()
 | 
			
		||||
        self.cursor = self.conn.cursor()
 | 
			
		||||
        self.check_schema_versions()
 | 
			
		||||
        self.verify_schema_versions()
 | 
			
		||||
 | 
			
		||||
        # Register the adapters and typecasters for enum types
 | 
			
		||||
        self.cursor.execute("SELECT NULL::status_enum")
 | 
			
		||||
@@ -520,11 +520,9 @@ class PostgresqlResultProcessor(OutputProcessor):
 | 
			
		||||
        self.conn.commit()
 | 
			
		||||
        self.conn.reset()
 | 
			
		||||
 | 
			
		||||
    def check_schema_versions(self):
 | 
			
		||||
        schemafilepath = os.path.join(POSTGRES_SCHEMA_DIR, 'postgres_schema.sql')
 | 
			
		||||
        cur_major_version, cur_minor_version, _ = get_schema(schemafilepath)
 | 
			
		||||
        db_schema_version = get_database_schema_version(self.cursor)
 | 
			
		||||
        if (cur_major_version, cur_minor_version) != db_schema_version:
 | 
			
		||||
    def verify_schema_versions(self):
 | 
			
		||||
        local_schema_version, db_schema_version = get_schema_versions(self.cursor)
 | 
			
		||||
        if local_schema_version != db_schema_version:
 | 
			
		||||
            self.cursor.close()
 | 
			
		||||
            self.cursor = None
 | 
			
		||||
            self.conn.commit()
 | 
			
		||||
@@ -532,8 +530,7 @@ class PostgresqlResultProcessor(OutputProcessor):
 | 
			
		||||
            msg = 'The current database schema is v{} however the local ' \
 | 
			
		||||
                  'schema version is v{}. Please update your database ' \
 | 
			
		||||
                  'with the create command'
 | 
			
		||||
            raise OutputProcessorError(msg.format(db_schema_version,
 | 
			
		||||
                                                  (cur_major_version, cur_minor_version)))
 | 
			
		||||
            raise OutputProcessorError(msg.format(db_schema_version, local_schema_version))
 | 
			
		||||
 | 
			
		||||
    def _sql_write_lobject(self, source, lobject):
 | 
			
		||||
        with open(source) as lobj_file:
 | 
			
		||||
 
 | 
			
		||||
@@ -28,6 +28,7 @@ http://initd.org/psycopg/docs/extensions.html#sql-adaptation-protocol-objects
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
import re
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
try:
 | 
			
		||||
    from psycopg2 import InterfaceError
 | 
			
		||||
@@ -39,6 +40,12 @@ except ImportError:
 | 
			
		||||
from wa.utils.types import level
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
POSTGRES_SCHEMA_DIR = os.path.join(os.path.dirname(__file__),
 | 
			
		||||
                                   '..',
 | 
			
		||||
                                   'commands',
 | 
			
		||||
                                   'postgres_schemas')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def cast_level(value, cur):  # pylint: disable=unused-argument
 | 
			
		||||
    """Generic Level caster for psycopg2"""
 | 
			
		||||
    if not InterfaceError:
 | 
			
		||||
@@ -217,3 +224,37 @@ def adapt_list(param):
 | 
			
		||||
            final_string = final_string + str(item) + ","
 | 
			
		||||
        final_string = "{" + final_string + "}"
 | 
			
		||||
    return AsIs("'{}'".format(final_string))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_schema(schemafilepath):
 | 
			
		||||
    with open(schemafilepath, 'r') as sqlfile:
 | 
			
		||||
        sql_commands = sqlfile.read()
 | 
			
		||||
 | 
			
		||||
    schema_major = None
 | 
			
		||||
    schema_minor = None
 | 
			
		||||
    # Extract schema version if present
 | 
			
		||||
    if sql_commands.startswith('--!VERSION'):
 | 
			
		||||
        splitcommands = sql_commands.split('!ENDVERSION!\n')
 | 
			
		||||
        schema_major, schema_minor = splitcommands[0].strip('--!VERSION!').split('.')
 | 
			
		||||
        schema_major = int(schema_major)
 | 
			
		||||
        schema_minor = int(schema_minor)
 | 
			
		||||
        sql_commands = splitcommands[1]
 | 
			
		||||
    return schema_major, schema_minor, sql_commands
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_database_schema_version(conn):
 | 
			
		||||
    with conn.cursor() as cursor:
 | 
			
		||||
        cursor.execute('''SELECT
 | 
			
		||||
                                  DatabaseMeta.schema_major,
 | 
			
		||||
                                  DatabaseMeta.schema_minor
 | 
			
		||||
                               FROM
 | 
			
		||||
                                  DatabaseMeta;''')
 | 
			
		||||
        schema_major, schema_minor = cursor.fetchone()
 | 
			
		||||
    return (schema_major, schema_minor)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_schema_versions(conn):
 | 
			
		||||
    schemafilepath = os.path.join(POSTGRES_SCHEMA_DIR, 'postgres_schema.sql')
 | 
			
		||||
    cur_major_version, cur_minor_version, _ = get_schema(schemafilepath)
 | 
			
		||||
    db_schema_version = get_database_schema_version(conn)
 | 
			
		||||
    return (cur_major_version, cur_minor_version), db_schema_version
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user