#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Copyright 2021 Recurve Analytics, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from collections import defaultdict
import sys
import csv
import logging
import sqlalchemy
import psycopg
from datetime import datetime
from flexvalue.config import FLEXValueConfig, FLEXValueException
from jinja2 import Environment, PackageLoader, select_autoescape
from sqlalchemy import create_engine, text, inspect
from sqlalchemy.engine import Engine
from sqlalchemy.exc import ResourceClosedError
from google.cloud import bigquery
from google.cloud.exceptions import NotFound
from google import api_core
# Values accepted for FLEXValueConfig.database_type (checked in
# DBManager.get_db_manager).
SUPPORTED_DBS = ("postgresql", "sqlite", "bigquery")
# NOTE(review): none of the names below is defined in the visible part of this
# module. Unless they are defined further down, `from <this module> import *`
# raises AttributeError -- confirm and update this list (e.g. to the manager
# classes).
__all__ = (
"get_db_connection",
"get_deer_load_shape",
"get_filtered_acc_elec",
"get_filtered_acc_gas",
)
# Column names, in file order, of the project info csv; used as the DictReader
# field names when that file is loaded.
PROJECT_INFO_FIELDS = [
    "id", "state", "utility", "region",
    "mwh_savings", "therms_savings",
    "load_shape", "therms_profile",
    "start_year", "start_quarter",
    "units", "eul", "ntg", "discount_rate",
    "admin_cost", "measure_cost", "incentive_cost",
    "value_curve_name",
]

# NOTE(review): not referenced in the visible part of this module.
ELEC_AV_COSTS_FIELDS = [
    "utility", "region", "year", "hour_of_year",
    "total", "marginal_ghg", "value_curve_name",
]

# Column names, in file order, of the gas avoided costs csv.
GAS_AV_COSTS_FIELDS = [
    "state", "utility", "region",
    "year", "quarter", "month",
    "market", "t_d", "environment", "btm_methane", "total",
    "upstream_methane", "marginal_ghg",
    "value_curve_name",
]

# Column names, in file order, of the electric avoided costs csv.
ELEC_AVOIDED_COSTS_FIELDS = [
    "state", "utility", "region",
    "datetime", "year", "quarter", "month", "hour_of_day", "hour_of_year",
    "energy", "losses", "ancillary_services", "capacity",
    "transmission", "distribution",
    "cap_and_trade", "ghg_adder", "ghg_rebalancing", "methane_leakage",
    "total", "marginal_ghg", "ghg_adder_rebalancing",
    "value_curve_name",
]

# Configures the root logger (INFO level, to stderr) as soon as this module
# is imported.
logging.basicConfig(
    level=logging.INFO, format="%(levelname)s:%(message)s", stream=sys.stderr
)

# How many bytes csv.Sniffer gets to look at when deciding whether a csv file
# starts with a header row. 4096 was arrived at empirically; files with many
# columns can fool the sniffer if it sees less than this.
HEADER_READ_SIZE = 4_096

# Rows read from a csv file per chunk when loading it in pieces.
INSERT_ROW_COUNT = 100_000

# Rows to insert into BigQuery in one go.
BIG_QUERY_CHUNK_SIZE = 10_000
class DBManager:
    """Load FLEXvalue input data into a database and run the calculation.

    This base class talks to the database through a sqlalchemy engine.
    Subclasses provide the database-specific pieces (``_get_db_connection_string``,
    ``_get_truncate_prefix``) and may override loaders with faster bulk paths.
    Use ``DBManager.get_db_manager`` to build the right subclass for a config.
    """

    # Aggregation columns the electric calculation can group by.
    _ELECTRIC_AGG_COLUMNS = frozenset(
        {
            "hour_of_year",
            "year",
            "region",
            "month",
            "quarter",
            "hour_of_day",
            "datetime",
        }
    )
    # Aggregation columns the gas calculation can group by.
    _GAS_AGG_COLUMNS = frozenset({"region", "year", "month", "quarter", "datetime"})
    # Input tables that must exist and contain rows before the calculation runs.
    _INPUT_TABLES = (
        "therms_profile",
        "project_info",
        "elec_av_costs",
        "gas_av_costs",
        "elec_load_shape",
    )
    # Calculation template for each mode accepted by _get_calculation_sql.
    _CALCULATION_TEMPLATES = {
        "both": "calculation.sql",
        "electric": "elec_calculation.sql",
        "gas": "gas_calculation.sql",
    }

    @staticmethod
    def get_db_manager(fv_config: FLEXValueConfig):
        """Factory for the correct instance of DBManager child class.

        Raises:
            FLEXValueException: if ``fv_config.database_type`` is missing or is
                not one of SUPPORTED_DBS.
        """
        if not fv_config.database_type or fv_config.database_type not in SUPPORTED_DBS:
            raise FLEXValueException(
                f"You must specify a database_type in your config file.\nThe valid choices are {SUPPORTED_DBS}"
            )
        if fv_config.database_type == "sqlite":
            return SqliteManager(fv_config)
        elif fv_config.database_type == "postgresql":
            return PostgresqlManager(fv_config)
        elif fv_config.database_type == "bigquery":
            return BigQueryManager(fv_config)
        else:
            # Defensive: unreachable while the SUPPORTED_DBS check above stays in sync.
            raise FLEXValueException(
                f"Unsupported database_type. Please choose one of {SUPPORTED_DBS}"
            )

    def __init__(self, fv_config: FLEXValueConfig) -> None:
        self.template_env = Environment(
            loader=PackageLoader("flexvalue", "templates"),
            autoescape=select_autoescape(),
            trim_blocks=True,
        )
        self.config = fv_config
        self.engine = self._get_db_engine(fv_config)

    def _get_db_connection_string(self, config: FLEXValueConfig) -> str:
        """Get the sqlalchemy db connection string for the given settings.

        Subclasses override this; the base implementation returns an empty
        string and is not meant to be called.
        """
        return ""

    def _get_db_engine(self, config: FLEXValueConfig) -> Engine:
        """Create the sqlalchemy engine for ``config``."""
        conn_str = self._get_db_connection_string(config)
        engine = create_engine(conn_str)
        # Log the URL with the password masked: the raw connection string can
        # contain database credentials.
        logging.debug(f"conn_str ={engine.url.render_as_string(hide_password=True)}")
        logging.debug(f"dialect = {engine.dialect.name}")
        return engine

    def _get_default_db_conn_str(self) -> str:
        """If no db config file is provided, default to a local sqlite database."""
        return "sqlite+pysqlite:///flexvalue.db"

    def process_elec_load_shape(self, elec_load_shapes_path: str, truncate=False):
        """Load the hourly electric load shapes csv into ``elec_load_shape``.

        The first 7 columns are fixed (state, utility, region, quarter, month,
        hour_of_day, hour_of_year). Every later column is one load shape,
        named by its header cell. The wide file is unpivoted into one row per
        (load shape, csv row). State, utility, region and load shape names are
        upper-cased.

        Args:
            elec_load_shapes_path: path to the csv; it must have a header row.
            truncate: if True, delete existing rows before loading.
        """
        self._prepare_table(
            "elec_load_shape",
            "flexvalue/sql/create_elec_load_shape.sql",
            truncate=truncate,
        )
        header, *data_rows = self._csv_file_to_rows(elec_load_shapes_path)
        params = [
            {
                "state": row[0].upper(),
                "utility": row[1].upper(),
                "region": row[2].upper(),
                "quarter": row[3],
                "month": row[4],
                "hour_of_day": row[5],
                "hour_of_year": row[6],
                "load_shape_name": header[col].upper(),
                "value": row[col],
            }
            for col in range(7, len(header))
            for row in data_rows
        ]
        self._execute_insert(
            self._file_to_string("flexvalue/templates/load_elec_load_shape.sql"),
            params,
        )

    def process_elec_av_costs(self, elec_av_costs_path: str, truncate=False):
        """Load the electric avoided costs csv into ``elec_av_costs`` in chunks.

        Each row also gets a ``date_str`` value (see ``_eac_dict_mapper``).
        """
        self._prepare_table(
            "elec_av_costs",
            "flexvalue/sql/create_elec_av_cost.sql",
            truncate=truncate,
        )
        logging.debug("about to load elec av costs")
        self._load_csv_file(
            elec_av_costs_path,
            "elec_av_costs",
            ELEC_AVOIDED_COSTS_FIELDS,
            "flexvalue/templates/load_elec_av_costs.sql",
            dict_processor=self._eac_dict_mapper,
        )

    def process_therms_profile(self, therms_profiles_path: str, truncate: bool = False):
        """Load the therms profiles csv into ``therms_profile``.

        The first 5 columns are fixed (state, utility, region, quarter, month).
        Every later column is one therms profile, named by its header cell.
        The wide file is unpivoted into one row per (profile, csv row). Values
        are stored as read; nothing is upper-cased.
        """
        self._prepare_table(
            "therms_profile",
            "flexvalue/sql/create_therms_profile.sql",
            truncate=truncate,
        )
        header, *data_rows = self._csv_file_to_rows(therms_profiles_path)
        params = [
            {
                "state": row[0],
                "utility": row[1],
                "region": row[2],
                "quarter": row[3],
                "month": row[4],
                "profile_name": header[col],
                "value": row[col],
            }
            for col in range(5, len(header))
            for row in data_rows
        ]
        self._execute_insert(
            self._file_to_string("flexvalue/templates/load_therms_profiles.sql"),
            params,
        )

    def process_gas_av_costs(self, gas_av_costs_path: str, truncate=False):
        """Load the gas avoided costs csv into ``gas_av_costs`` in chunks."""
        self._prepare_table(
            "gas_av_costs", "flexvalue/sql/create_gas_av_cost.sql", truncate=truncate
        )
        self._load_csv_file(
            gas_av_costs_path,
            "gas_av_costs",
            GAS_AV_COSTS_FIELDS,
            "flexvalue/templates/load_gas_av_costs.sql",
        )

    def _eac_dict_mapper(self, dict_to_process):
        """Add ``date_str`` (the 'yyyy-mm-dd' prefix of ``datetime``) to one
        electric avoided costs row, in place, and return the row."""
        dict_to_process["date_str"] = dict_to_process["datetime"][:10]
        return dict_to_process

    def _file_to_string(self, filename):
        """Return the full text of ``filename``."""
        with open(filename) as f:
            return f.read()

    def reset_elec_load_shape(self):
        logging.debug("Resetting elec load shape")
        self._reset_table("elec_load_shape")

    def reset_elec_av_costs(self):
        logging.debug("Resetting elec_av_costs")
        self._reset_table("elec_av_costs")

    def reset_therms_profiles(self):
        logging.debug("Resetting therms_profile")
        self._reset_table("therms_profile")

    def reset_gas_av_costs(self):
        logging.debug("Resetting gas avoided costs")
        self._reset_table("gas_av_costs")

    def _execute_insert(self, insert_text: str, params: list) -> None:
        """Execute ``insert_text`` with the list of parameter dicts ``params``
        in one transaction.

        Does nothing when ``params`` is empty, instead of running the INSERT
        with no parameter sets.
        """
        if not params:
            return
        with self.engine.begin() as conn:
            conn.execute(text(insert_text), params)

    def _reset_table(self, table_name):
        """Delete all rows from ``table_name``; a no-op if the table doesn't exist."""
        # Check for the table instead of swallowing errors: a missing table is
        # reported as different exception types by different backends (sqlite
        # raises OperationalError), and other failures should not be hidden.
        if not self._table_exists(table_name):
            return
        truncate_prefix = self._get_truncate_prefix()
        with self.engine.begin() as conn:
            conn.execute(text(f"{truncate_prefix} {table_name}"))

    def _get_truncate_prefix(self):
        raise FLEXValueException(
            "You need to implement _get_truncate_prefix for your database manager."
        )

    def _prepare_table(
        self,
        table_name: str,
        sql_filepath: str,
        index_filepaths=None,
        truncate: bool = False,
    ):
        """Create ``table_name`` from the DDL in ``sql_filepath`` (plus the DDL in
        each of ``index_filepaths``) if it doesn't exist. If it already exists
        and ``truncate`` is True, delete its rows."""
        self._ensure_table(
            table_name,
            lambda: self._file_to_string(sql_filepath),
            index_filepaths,
            truncate,
        )

    def _prepare_table_from_str(
        self,
        table_name: str,
        create_table_sql: str,
        index_filepaths=None,
        truncate: bool = False,
    ):
        """Same as ``_prepare_table``, but the CREATE statement is passed as a string."""
        self._ensure_table(
            table_name, lambda: create_table_sql, index_filepaths, truncate
        )

    def _ensure_table(self, table_name, get_create_sql, index_filepaths, truncate):
        """Shared body of the _prepare_table* methods.

        ``get_create_sql`` is only called when the table has to be created.
        A freshly created table is empty, so truncation only applies to an
        existing table; it runs in its own transaction, not inside the DDL one.
        """
        if self._table_exists(table_name):
            if truncate:
                self._reset_table(table_name)
            return
        with self.engine.begin() as conn:
            conn.execute(text(get_create_sql()))
            for index_filepath in index_filepaths or ():
                conn.execute(text(self._file_to_string(index_filepath)))

    def _table_exists(self, table_name):
        inspection = inspect(self.engine)
        return inspection.has_table(table_name)

    def run(self):
        logging.debug(f"About to start calculation, it is {datetime.now()}")
        self._perform_calculation()
        logging.debug(f"after calc, it is {datetime.now()}")

    def process_project_info(self, project_info_path: str):
        """Load the project info csv into ``project_info``, replacing existing rows.

        load_shape, state, region and utility are upper-cased. Each project
        also gets ``start_date`` (first day of the first month of its start
        quarter) and ``end_date`` (the same month, ``eul`` years later).
        """
        self._prepare_table(
            "project_info",
            "flexvalue/sql/create_project_info.sql",
            index_filepaths=[
                "flexvalue/sql/project_info_index.sql",
                "flexvalue/sql/project_info_dates_index.sql",
            ],
            truncate=True,
        )
        dicts = self._csv_file_to_dicts(
            project_info_path,
            fieldnames=PROJECT_INFO_FIELDS,
            fields_to_upper=["load_shape", "state", "region", "utility"],
        )
        for d in dicts:
            start_year = int(d["start_year"])
            eul = int(d["eul"])
            month = self._quarter_to_month(d["start_quarter"])
            d["start_date"] = f"{start_year}-{month}-01"
            d["end_date"] = f"{start_year + eul}-{month}-01"
        insert_text = self._file_to_string("flexvalue/templates/load_project_info.sql")
        self._load_project_info_data(insert_text, dicts)

    def _load_project_info_data(self, insert_text, project_info_dicts):
        """Insert the prepared project info dicts using ``insert_text``."""
        self._execute_insert(insert_text, project_info_dicts)

    def _quarter_to_month(self, qtr):
        """Return the first month of quarter ``qtr`` as a zero-padded string,
        e.g. 1 -> "01", 3 -> "07"."""
        quarter = int(qtr)
        return f"{(quarter - 1) * 3 + 1:02d}"

    def _get_empty_tables(self):
        """Return the names of required input tables that are missing or have no rows."""
        empty_tables = []
        inspection = inspect(self.engine)
        with self.engine.begin() as conn:
            for table_name in self._INPUT_TABLES:
                if not inspection.has_table(table_name):
                    empty_tables.append(table_name)
                    continue
                row_count = conn.execute(
                    text(f"SELECT COUNT(*) FROM {table_name}")
                ).scalar()
                if row_count == 0:
                    empty_tables.append(table_name)
        return empty_tables

    # TODO: allow better configuration of gas vs electric table names
    def _perform_calculation(self):
        """Run the calculation SQL, once per output table if they are separate.

        Raises:
            FLEXValueException: if any required input table is missing or empty.
        """
        empty_tables = self._get_empty_tables()
        if empty_tables:
            raise FLEXValueException(
                f"Not all data has been loaded. Please provide data for the following tables: {', '.join(empty_tables)}"
            )
        if self.config.separate_output_tables:
            sql = self._get_calculation_sql(mode="electric")
            logging.info(f"electric sql =\n{sql}")
            self._run_calc(sql)
            sql = self._get_calculation_sql(mode="gas")
            logging.info(f"gas sql =\n{sql}")
            self._run_calc(sql)
        else:
            sql = self._get_calculation_sql()
            logging.info(f"sql =\n{sql}")
            self._run_calc(sql)

    def _run_calc(self, sql):
        """Execute ``sql``. If no output table is configured, write the result
        as comma-separated text (header line first) to ``config.output_file``,
        or to stdout when no output file is set."""
        with self.engine.begin() as conn:
            result = conn.execute(text(sql))
            if (
                not self.config.output_table
                and not self.config.electric_output_table
                and not self.config.gas_output_table
            ):
                if self.config.output_file:
                    with open(self.config.output_file, "w") as outfile:
                        outfile.write(", ".join(result.keys()) + "\n")
                        for row in result:
                            outfile.write(", ".join([str(col) for col in row]) + "\n")
                else:
                    try:
                        print(", ".join(result.keys()))
                        for row in result:
                            print(", ".join([str(col) for col in row]))
                    except ResourceClosedError:
                        # If the query doesn't return rows (e.g. we are writing to
                        # an output table), don't error out.
                        pass

    def _get_calculation_sql(self, mode="both"):
        """Render the calculation SQL for ``mode`` ("both", "electric" or "gas").

        Raises:
            FLEXValueException: for any other mode.
        """
        template_name = self._CALCULATION_TEMPLATES.get(mode)
        if template_name is None:
            raise FLEXValueException(
                f"Unknown calculation mode {mode!r}; expected one of {tuple(self._CALCULATION_TEMPLATES)}"
            )
        if mode == "both":
            context = self._get_calculation_sql_context()
        else:
            context = self._get_calculation_sql_context(mode=mode)
        template = self.template_env.get_template(template_name)
        return template.render(context)

    def _get_calculation_sql_context(self, mode=""):
        """Build the template context for the calculation SQL.

        ``mode`` "electric" or "gas" includes only that fuel's aggregation
        columns and additional fields; any other value includes both (gas
        fields already present on the electric side are left out). When an
        output table is configured, ``create_clause`` drops and recreates it.
        """
        elec_agg_columns = self._elec_aggregation_columns()
        gas_agg_columns = self._gas_aggregation_columns()
        elec_addl_fields = self._elec_addl_fields(elec_agg_columns)
        gas_addl_fields = self._gas_addl_fields(gas_agg_columns)
        context = {
            "project_info_table": "project_info",
            "eac_table": "elec_av_costs",
            "els_table": "elec_load_shape",
            "gac_table": "gas_av_costs",
            "therms_profile_table": "therms_profile",
            "float_type": self.config.float_type(),
            "database_type": self.config.database_type,
            "elec_components": self._elec_components(),
            "gas_components": self._gas_components(),
            "use_value_curve_name_for_join": self.config.use_value_curve_name_for_join,
        }
        if mode == "electric":
            context["elec_aggregation_columns"] = elec_agg_columns
            context["elec_addl_fields"] = elec_addl_fields
        elif mode == "gas":
            context["gas_aggregation_columns"] = gas_agg_columns
            context["gas_addl_fields"] = gas_addl_fields
        else:
            context["elec_aggregation_columns"] = elec_agg_columns
            context["gas_aggregation_columns"] = gas_agg_columns
            context["elec_addl_fields"] = elec_addl_fields
            context["gas_addl_fields"] = set(gas_addl_fields) - set(elec_addl_fields)
        if (
            self.config.output_table
            or self.config.electric_output_table
            or self.config.gas_output_table
        ):
            table_name = self.config.output_table
            if mode == "electric":
                table_name = self.config.electric_output_table
            elif mode == "gas":
                table_name = self.config.gas_output_table
            context[
                "create_clause"
            ] = f"DROP TABLE IF EXISTS {table_name}; CREATE TABLE {table_name} AS ("
        return context

    def _elec_aggregation_columns(self):
        """Configured aggregation columns that apply to the electric calculation."""
        return set(self.config.aggregation_columns) & self._ELECTRIC_AGG_COLUMNS

    def _gas_aggregation_columns(self):
        """Configured aggregation columns that apply to the gas calculation."""
        return set(self.config.aggregation_columns) & self._GAS_AGG_COLUMNS

    def _elec_addl_fields(self, elec_agg_columns):
        """Configured extra electric fields, minus those already aggregated on."""
        fields = set(self.config.elec_addl_fields) - set(elec_agg_columns)
        logging.debug(
            f"elec_addl_fields = {self.config.elec_addl_fields}\nelec_agg_columns = {elec_agg_columns}\nset diff = {fields}"
        )
        return fields

    def _gas_addl_fields(self, gas_agg_columns):
        """Configured extra gas fields, minus aggregation columns and ``total``."""
        fields = (
            set(self.config.gas_addl_fields) - set(gas_agg_columns) - set(["total"])
        )
        logging.debug(
            f"gas_addl_fields = {self.config.gas_addl_fields}\ngas_agg_columns = {gas_agg_columns}\nset diff = {fields}"
        )
        return fields

    def _elec_components(self):
        """Configured electric cost components, excluding ``total``."""
        fields = set(self.config.elec_components) - set(["total"])
        logging.debug(
            f"elec_components = {self.config.elec_components}\n diff = {fields}"
        )
        return fields

    def _gas_components(self):
        """Configured gas cost components, excluding ``total``."""
        fields = set(self.config.gas_components) - set(["total"])
        logging.debug(
            f"gas_components = {self.config.gas_components}\n diff = {fields}"
        )
        return fields

    def _csv_file_to_dicts(
        self, csv_file_path: str, fieldnames: list, fields_to_upper=None
    ):
        """Read the csv file at ``csv_file_path`` into a list of dicts.

        Columns are mapped to ``fieldnames`` by position. A header row is
        skipped if csv.Sniffer detects one. Values of the fields named in
        ``fields_to_upper`` (which must be names from ``fieldnames``) are
        upper-cased; when it is None, nothing is upper-cased.
        """
        fields_to_upper = fields_to_upper or ()
        dicts = []
        with open(csv_file_path, newline="") as f:
            has_header = csv.Sniffer().has_header(f.read(HEADER_READ_SIZE))
            f.seek(0)
            csv_reader = csv.DictReader(f, fieldnames=fieldnames)
            if has_header:
                next(csv_reader)
            for row in csv_reader:
                for field in fields_to_upper:
                    row[field] = row[field].upper()
                dicts.append(row)
        return dicts

    def _csv_file_to_rows(self, csv_file_path: str):
        """Read a whole csv file into memory as a list of rows (each a list of
        strings), header row first. Don't use this on very large files.

        Raises:
            FLEXValueException: if csv.Sniffer finds no header row.
        """
        with open(csv_file_path, newline="") as f:
            has_header = csv.Sniffer().has_header(f.read(HEADER_READ_SIZE))
            if not has_header:
                raise FLEXValueException(
                    f"The file you provided, {csv_file_path}, "
                    "doesn't seem to have a header row. Please provide a header row "
                    "containing the column names."
                )
            f.seek(0)
            return list(csv.reader(f))

    def _load_csv_file(
        self,
        csv_file_path: str,
        table_name: str,
        fieldnames,
        load_sql_file_path: str,
        dict_processor=None,
    ):
        """Load a csv file into a table in chunks of INSERT_ROW_COUNT rows.

        Some input files can be over a gibibyte, so rows are streamed and
        inserted chunk by chunk, all in one transaction.

        Args:
            csv_file_path: the csv to read; a header row, if detected, is skipped.
            table_name: the target table (the INSERT itself comes from
                ``load_sql_file_path``).
            fieldnames: column names, in file order, used as the row dict keys.
            load_sql_file_path: file with the parameterized INSERT statement.
            dict_processor: optional function applied to each row dict before
                insertion; it takes one dict and returns one dict.
        """
        insert_stmt = text(self._file_to_string(load_sql_file_path))
        with open(csv_file_path, newline="") as f:
            has_header = csv.Sniffer().has_header(f.read(HEADER_READ_SIZE))
            f.seek(0)
            csv_reader = csv.DictReader(f, fieldnames=fieldnames)
            if has_header:
                next(csv_reader)
            with self.engine.begin() as conn:
                buffer = []
                for row in csv_reader:
                    buffer.append(dict_processor(row) if dict_processor else row)
                    if len(buffer) == INSERT_ROW_COUNT:
                        conn.execute(insert_stmt, buffer)
                        buffer = []
                # Flush the last partial chunk. Skip the call when nothing is
                # left (empty file, or a row count that is an exact multiple of
                # the chunk size) rather than executing with no parameter sets.
                if buffer:
                    conn.execute(insert_stmt, buffer)

    def _exec_select_sql(self, sql: str):
        """Return the rows of ``sql`` as a list, copied out of the sqlalchemy
        result. This is just here to support testing."""
        with self.engine.begin() as conn:
            return list(conn.execute(text(sql)))
class PostgresqlManager(DBManager):
    """DBManager for PostgreSQL.

    DDL and queries go through the sqlalchemy engine built by DBManager. Bulk
    loads use a separate psycopg connection and ``COPY ... FROM STDIN``.
    """

    # COPY column order for gas_av_costs; ``datetime`` is derived from year/month.
    _GAS_AV_COSTS_COPY_COLUMNS = (
        "state",
        "utility",
        "region",
        "year",
        "quarter",
        "month",
        "datetime",
        "market",
        "t_d",
        "environment",
        "btm_methane",
        "total",
        "upstream_methane",
        "marginal_ghg",
        "value_curve_name",
    )
    _ELEC_LOAD_SHAPE_COPY_COLUMNS = (
        "state",
        "utility",
        "region",
        "quarter",
        "month",
        "hour_of_day",
        "hour_of_year",
        "load_shape_name",
        "value",
    )
    _METERED_LOAD_SHAPE_COPY_COLUMNS = (
        "hour_of_year",
        "utility",
        "load_shape_name",
        "value",
    )
    _PROJECT_INFO_COPY_COLUMNS = (
        "id",
        "state",
        "utility",
        "region",
        "mwh_savings",
        "therms_savings",
        "load_shape",
        "therms_profile",
        "start_year",
        "start_quarter",
        "start_date",
        "end_date",
        "units",
        "eul",
        "ntg",
        "discount_rate",
        "admin_cost",
        "measure_cost",
        "incentive_cost",
        "value_curve_name",
    )

    def __init__(self, fv_config: FLEXValueConfig) -> None:
        super().__init__(fv_config)
        self.connection = psycopg.connect(
            dbname=self.config.database,
            host=self.config.host,
            port=self.config.port,
            user=self.config.user,
            password=self.config.password,
        )
        logging.debug(f"connection = {self.connection}")

    def _get_db_connection_string(self, config: FLEXValueConfig) -> str:
        """Build the sqlalchemy URL for the psycopg driver.

        The user name and password are percent-encoded so that characters
        such as '@', ':' or '/' in a password can't corrupt the URL.
        """
        from urllib.parse import quote

        user = quote(str(config.user), safe="")
        password = quote(str(config.password), safe="")
        host = config.host
        port = config.port
        database = config.database
        return f"postgresql+psycopg://{user}:{password}@{host}:{port}/{database}"

    def _get_truncate_prefix(self):
        return "TRUNCATE TABLE"

    @staticmethod
    def _copy_statement(table_name, columns):
        """Return ``COPY table_name (col, ...) FROM STDIN``."""
        return f"COPY {table_name} ({', '.join(columns)}) FROM STDIN"

    def _copy_rows(self, copy_sql: str, rows) -> None:
        """Stream ``rows`` (an iterable of sequences) through one COPY and commit.

        On any error, including one raised while producing ``rows``, the
        psycopg transaction is rolled back so the connection stays usable,
        and the exception is re-raised.
        """
        try:
            with self.connection.cursor() as cur:
                with cur.copy(copy_sql) as copy:
                    for row in rows:
                        copy.write_row(row)
        except Exception:
            self.connection.rollback()
            raise
        self.connection.commit()

    @staticmethod
    def _load_shape_columns(fieldnames, csv_path, anchor="hour_of_year"):
        """Return the raw header names that follow the ``anchor`` column.

        Raises:
            FLEXValueException: if there is no header or no ``anchor`` column.
        """
        stripped = [name.strip() for name in fieldnames or ()]
        if anchor not in stripped:
            raise FLEXValueException(
                f"The file {csv_path} must have a header row containing a '{anchor}' column."
            )
        return fieldnames[stripped.index(anchor) + 1:]

    def process_gas_av_costs(self, gas_av_costs_path: str, truncate=False):
        """Bulk-load the gas avoided costs csv into ``gas_av_costs`` with COPY.

        Each row gets a ``datetime`` value built from its year and month.
        Errors while reading or loading are logged, not raised; the load is
        rolled back.
        """
        self._prepare_table(
            "gas_av_costs", "flexvalue/sql/create_gas_av_cost.sql", truncate=truncate
        )
        logging.info("IN PG VERSION OF LOAD GAS AV COSTS")
        copy_sql = self._copy_statement("gas_av_costs", self._GAS_AV_COSTS_COPY_COLUMNS)
        try:
            self._copy_rows(copy_sql, self._gas_av_cost_rows(gas_av_costs_path))
        except Exception as e:
            logging.error(f"Error loading the gas avoided costs: {e}")

    @staticmethod
    def _gas_av_cost_rows(gas_av_costs_path):
        """Yield one row per csv line, in _GAS_AV_COSTS_COPY_COLUMNS order."""
        with open(gas_av_costs_path, newline="") as f:
            for r in csv.DictReader(f):
                year = int(r["year"])
                month = int(r["month"])
                # Midnight on the first of the month. The datetime is naive,
                # so %Z renders as an empty string.
                gac_timestamp = datetime(year=year, month=month, day=1).strftime(
                    "%Y-%m-%d %H:%M:%S %Z"
                )
                yield [
                    r["state"],
                    r["utility"],
                    r["region"],
                    year,
                    int(r["quarter"]),
                    month,
                    gac_timestamp,
                    float(r["market"]),
                    float(r["t_d"]),
                    float(r["environment"]),
                    float(r["btm_methane"]),
                    float(r["total"]),
                    float(r["upstream_methane"]),
                    float(r["marginal_ghg"]),
                    r["value_curve_name"],
                ]

    def process_elec_av_costs(self, elec_av_costs_path, truncate=False):
        """Bulk-load the electric avoided costs csv into ``elec_av_costs`` with COPY.

        Errors while reading or loading are logged, not raised; the load is
        rolled back.
        """
        self._prepare_table(
            "elec_av_costs",
            "flexvalue/sql/create_elec_av_cost.sql",
            truncate=truncate,
        )
        logging.debug("in pg version of load_elec_av_costs")
        copy_sql = self._copy_statement("elec_av_costs", ELEC_AVOIDED_COSTS_FIELDS)
        try:
            self._copy_rows(copy_sql, self._elec_av_cost_rows(elec_av_costs_path))
        except Exception as e:
            logging.error(f"Error loading the electric avoided costs: {e}")

    @staticmethod
    def _elec_av_cost_rows(elec_av_costs_path):
        """Yield one row per csv line, in ELEC_AVOIDED_COSTS_FIELDS order.

        ``datetime`` is parsed and ``total`` converted to float; every other
        value is passed through as read.
        """
        with open(elec_av_costs_path, newline="") as f:
            for r in csv.DictReader(f):
                values = dict(r)
                values["datetime"] = datetime.strptime(
                    r["datetime"], "%Y-%m-%d %H:%M:%S %Z"
                )
                values["total"] = float(r["total"])
                yield [values[column] for column in ELEC_AVOIDED_COSTS_FIELDS]

    def process_elec_load_shape(self, elec_load_shapes_path: str, truncate=False):
        """Bulk-load the wide electric load shapes csv into ``elec_load_shape``.

        Every column after ``hour_of_year`` is one load shape, and the file is
        unpivoted into one row per (csv row, load shape). State, utility,
        region and load shape names are upper-cased.
        """
        self._prepare_table(
            "elec_load_shape",
            "flexvalue/sql/create_elec_load_shape.sql",
            truncate=truncate,
        )
        self._copy_rows(
            self._copy_statement("elec_load_shape", self._ELEC_LOAD_SHAPE_COPY_COLUMNS),
            self._elec_load_shape_rows(elec_load_shapes_path),
        )

    def _elec_load_shape_rows(self, elec_load_shapes_path):
        """Yield unpivoted rows in _ELEC_LOAD_SHAPE_COPY_COLUMNS order."""
        with open(elec_load_shapes_path, newline="") as f:
            reader = csv.DictReader(f)
            shape_columns = self._load_shape_columns(
                reader.fieldnames, elec_load_shapes_path
            )
            for r in reader:
                for column in shape_columns:
                    yield (
                        r["state"].upper(),
                        r["utility"].upper(),
                        r["region"].upper(),
                        int(r["quarter"]),
                        int(r["month"]),
                        int(r["hour_of_day"]),
                        int(r["hour_of_year"]),
                        column.strip().upper(),
                        float(r[column]),
                    )

    def process_metered_load_shape(self, metered_load_shape_path: str):
        """Load metered load shapes referenced by ``project_info`` into ``elec_load_shape``.

        Note this has to be run after process_project_info, as it depends on
        the utility for each project having been loaded. Only load shapes
        used in project_info but not yet in elec_load_shape are loaded, once
        per utility that uses them.
        """
        metered_load_shape_query = "SELECT distinct utility, load_shape from project_info where load_shape not in (select distinct load_shape_name from elec_load_shape);"
        load_shapes_utils = defaultdict(list)
        with self.engine.begin() as conn:
            for utility, load_shape in conn.execute(text(metered_load_shape_query)):
                load_shapes_utils[load_shape.upper()].append(utility)
        self._copy_rows(
            self._copy_statement(
                "elec_load_shape", self._METERED_LOAD_SHAPE_COPY_COLUMNS
            ),
            self._metered_load_shape_rows(metered_load_shape_path, load_shapes_utils),
        )

    def _metered_load_shape_rows(self, metered_load_shape_path, load_shapes_utils):
        """Yield rows in _METERED_LOAD_SHAPE_COPY_COLUMNS order.

        A project can share a metered load shape with projects of other
        utilities, so each value is emitted once per utility in
        ``load_shapes_utils``; shapes not listed there are skipped.
        """
        with open(metered_load_shape_path, newline="") as f:
            reader = csv.DictReader(f)
            shape_columns = self._load_shape_columns(
                reader.fieldnames, metered_load_shape_path
            )
            for row in reader:
                for column in shape_columns:
                    shape_name = column.strip().upper()
                    for util in load_shapes_utils.get(shape_name, ()):
                        yield [
                            int(row["hour_of_year"]),
                            util.upper(),
                            shape_name,
                            float(row[column]),
                        ]

    def _load_project_info_data(self, insert_text, project_info_dicts):
        """Bulk-load the prepared project info dicts with COPY.

        insert_text isn't needed for postgresql.
        """
        columns = self._PROJECT_INFO_COPY_COLUMNS
        self._copy_rows(
            self._copy_statement("project_info", columns),
            (tuple(d[column] for column in columns) for d in project_info_dicts),
        )
class SqliteManager(DBManager):
    """DBManager for SQLite, accessed through sqlalchemy's pysqlite driver."""

    def __init__(self, fv_config: FLEXValueConfig):
        super().__init__(fv_config)
        # NOTE(review): this replaces the base class's environment with one
        # that doesn't set trim_blocks=True, so templates render with different
        # whitespace here -- confirm that is intended.
        self.template_env = Environment(
            loader=PackageLoader("flexvalue", "templates"),
            autoescape=select_autoescape(),
        )
        self.config = fv_config

    def _get_truncate_prefix(self):
        """sqlite doesn't support TRUNCATE"""
        return "DELETE FROM"

    def _get_db_connection_string(self, config: FLEXValueConfig) -> str:
        """Build the sqlalchemy URL for ``config.database``.

        - None: fall back to the default local database file.
        - "" : in-memory database ("sqlite+pysqlite://").
        - A value starting with "/": appended as-is, as before.
        - Anything else is a file path: a "/" is added, because without it
          "sqlite+pysqlite://name.db" would treat "name.db" as a host, which
          the sqlite dialect rejects.
        """
        database = config.database
        if database is None:
            return self._get_default_db_conn_str()
        if database and not database.startswith("/"):
            database = f"/{database}"
        return f"sqlite+pysqlite://{database}"
[docs]class BigQueryManager(DBManager):
def __init__(self, fv_config: FLEXValueConfig):
    """Set up a BigQuery-backed manager.

    After the base initializer runs, this installs a template environment
    without ``trim_blocks``, records the configured source tables in
    ``self.table_names``, and opens a BigQuery client for ``fv_config.project``.
    """
    super().__init__(fv_config)
    self.template_env = Environment(
        loader=PackageLoader("flexvalue", "templates"),
        autoescape=select_autoescape(),
    )
    self.config = fv_config
    cfg = self.config
    self.table_names = [
        cfg.elec_av_costs_table,
        cfg.gas_av_costs_table,
        cfg.elec_load_shape_table,
        cfg.therms_profiles_table,
        cfg.project_info_table,
    ]
    self.client = bigquery.Client(project=cfg.project)
def _get_target_dataset(self):
    """Return the dataset (everything before the last '.') of the configured
    output table.

    Output tables are used because we know we can write to them. With
    separate output tables, both are required, so the electric one is used.
    """
    if self.config.separate_output_tables:
        qualified_name = self.config.electric_output_table
    else:
        qualified_name = self.config.output_table
    dataset, _, _ = qualified_name.rpartition(".")
    return dataset
def _test_connection(self):
    """Smoke-test the BigQuery client by printing the row count of a fixed
    example table."""
    logging.debug("in bigquerymanager._test_connection")
    count_sql = "select count(*) from flexvalue_refactor_tables.example_user_inputs"
    for row in self.client.query(count_sql).result():
        row_count = row.values()[0]
        print(f"There are {row_count} rows in example_user_inputs")
def _get_truncate_prefix(self):
    """Return the statement prefix used to empty a table.

    In BigQuery, TRUNCATE TABLE deletes row-level security, so DELETE is used.
    """
    prefix = "DELETE"
    return prefix
def _get_db_engine(self, config: FLEXValueConfig) -> Engine:
    """Return no engine: BigQuery is reached through ``self.client``, not sqlalchemy.

    TODO: refactor so this override isn't necessary.
    """
    no_engine = None
    return no_engine
def _table_exists(self, table_name):
    """Return True if BigQuery can fetch ``table_name``, False if it is NotFound.

    Pattern from the Google docs:
    https://cloud.google.com/bigquery/docs/samples/bigquery-table-exists#bigquery_table_exists-python
    """
    try:
        self.client.get_table(table_name)
    except NotFound:
        return False
    return True
def _get_empty_tables(self):
    """Return the tables in ``self.table_names`` that are missing or have no rows."""
    empty_tables = []
    for table_name in self.table_names:
        if not self._table_exists(table_name):
            empty_tables.append(table_name)
            continue
        sql = f"SELECT COUNT(*) FROM {table_name}"
        result = self.client.query(sql).result()  # API request
        # Read the count by position: BigQuery names an unaliased COUNT(*)
        # column "f0_", so there is no "count" field to look up by name.
        for row in result:  # there will be only one, but we have to iterate
            if row[0] == 0:
                empty_tables.append(table_name)
    return empty_tables
def _prepare_table(
    self,
    table_name: str,
    sql_filepath: str,
    index_filepaths=[],
    truncate: bool = False,
):
    """Create ``table_name`` if it is missing, or optionally empty it if it exists.

    Args:
        table_name: dataset-qualified table name ("{dataset}.{table}").
        sql_filepath: name of the template that is rendered (with ``dataset``)
            to produce the CREATE statement.
        index_filepaths: accepted for interface compatibility; not used here.
        truncate: if True and the table exists, all rows are deleted; the
            table itself is not dropped.
    """
    if self._table_exists(table_name):
        if truncate:
            self.client.query(f"DELETE FROM {table_name} WHERE TRUE;").result()
        return
    dataset = table_name.rpartition(".")[0]  # everything before the last '.'
    create_sql = self.template_env.get_template(sql_filepath).render(
        {"dataset": dataset}
    )
    logging.debug(f"create sql = \n{create_sql}")
    self.client.query(create_sql).result()
def process_elec_av_costs(self, elec_av_costs_path: str, truncate=False):
    """No-op for BigQuery: the configured electric avoided costs table is used
    as provided. Both arguments are accepted for interface compatibility."""
    return None
def process_gas_av_costs(self, gas_av_costs_path: str, truncate=False):
    """Make sure the configured gas avoided costs table has a DATETIME
    ``datetime`` column, then set it to midnight on the first of each row's
    year/month. Later calculations join on it.

    ``gas_av_costs_path`` and ``truncate`` are not used.
    """
    logging.debug("In bq process_gas_av_costs")
    gac_table = self.config.gas_av_costs_table
    self._ensure_datetime_column(gac_table)
    update_sql = f'UPDATE {gac_table} gac SET datetime = (DATETIME(FORMAT("%d-%d-01 00:00:00", gac.year, gac.month))) WHERE TRUE;'
    self.client.query(update_sql).result()
def _ensure_datetime_column(self, table_name):
    """Ensure that the table named ``table_name`` has a ``datetime`` column of
    type ``DATETIME``, adding it to the schema if it is missing.

    Raises:
        FLEXValueException: if a ``datetime`` column exists with another type,
            or if the schema update does not add exactly one column.
    """
    table = self.client.get_table(table_name)
    # BigQuery column names are case-insensitive, so compare lower-cased names.
    column_types = {field.name.lower(): field.field_type for field in table.schema}
    existing_type = column_types.get("datetime")
    if existing_type == "DATETIME":
        return
    if existing_type is not None:
        # Appending a second "datetime" column would be rejected as a
        # duplicate; fail with an actionable message instead.
        raise FLEXValueException(
            f"{table_name} already has a datetime column of type {existing_type}, but DATETIME is required; can't process gas avoided costs."
        )
    original_schema = table.schema
    new_schema = list(original_schema)  # copy, so original_schema stays intact
    new_schema.append(bigquery.SchemaField("datetime", "DATETIME"))
    table.schema = new_schema
    table = self.client.update_table(table, ["schema"])  # Make an API request.
    if len(table.schema) == len(original_schema) + 1 == len(new_schema):
        # Log rather than print, like the rest of this module, so stdout is
        # left for query output.
        logging.info("A new column has been added.")
    else:
        raise FLEXValueException(
            f"Unable to add a datetime column to {table_name}; can't process gas avoided costs."
        )
def _copy_table(self, source_table, target_table):
    """Replace ``target_table`` with a copy of ``source_table``.

    Both names must include the dataset, like {dataset}.{table}. The target
    is deleted first (if present), then a copy job with WRITE_TRUNCATE and
    CREATE_IF_NEEDED is run and waited on.
    """
    self.client.delete_table(target_table, not_found_ok=True)
    copy_config = bigquery.CopyJobConfig()
    copy_config.write_disposition = bigquery.WriteDisposition.WRITE_TRUNCATE
    copy_config.create_disposition = bigquery.CreateDisposition.CREATE_IF_NEEDED
    job = self.client.copy_table(source_table, target_table, job_config=copy_config)
    job.result()
def process_elec_load_shape(self, elec_load_shapes_path: str, truncate=False):
    """Transform the table named by config.elec_load_shape_table into
    ``{target_dataset}.elec_load_shape``.

    The target table is always emptied first (the ``truncate`` argument is
    not consulted); ``elec_load_shapes_path`` is not used.
    """
    target_dataset = self._get_target_dataset()
    self._prepare_table(
        f"{target_dataset}.elec_load_shape",
        "bq_create_elec_load_shape.sql",
        truncate=True,
    )
    populate_template = self.template_env.get_template(
        "bq_populate_elec_load_shape.sql"
    )
    source_table = self.config.elec_load_shape_table
    source_dataset, _, source_table_name = source_table.rpartition(".")
    sql = populate_template.render(
        {
            "target_dataset": target_dataset,
            "source_dataset": source_dataset,
            "elec_load_shape_table": source_table,
            "elec_load_shape_table_name_only": source_table_name,
        }
    )
    logging.info(f"elec_load_shape sql = {sql}")
    self.client.query(sql).result()
def process_metered_load_shape(self, metered_load_shapes_path: str, truncate=False):
    """Load metered load shapes into ``{target_dataset}.elec_load_shape``.

    Steps:
      1. Prepare the target table with ``bq_create_elec_load_shape.sql``.
      2. Unless ``config.process_elec_load_shape`` is set (that step builds
         the target table itself), replace the target table with a copy of
         ``config.elec_load_shape_table``.
      3. Render ``bq_populate_metered_load_shape.sql`` with the metered and
         project-info table names and run it.

    Args:
        metered_load_shapes_path: Not used by this implementation; the source
            is the table named by ``config.metered_load_shape_table``.
        truncate: Forwarded to ``_prepare_table``.
    """
    target_dataset = self._get_target_dataset()
    self._prepare_table(
        f"{target_dataset}.elec_load_shape",
        "bq_create_elec_load_shape.sql",
        truncate=truncate,
    )
    if not self.config.process_elec_load_shape:
        # Nothing else will have built the target table, so seed it from the
        # configured elec load shape table.
        self._copy_table(
            self.config.elec_load_shape_table,
            f"{self._get_target_dataset()}.elec_load_shape",
        )
    populate_template = self.template_env.get_template(
        "bq_populate_metered_load_shape.sql"
    )
    metered_table = self.config.metered_load_shape_table
    # "project.dataset.table" -> ("project.dataset", ".", "table")
    metered_dataset, _, metered_table_name = metered_table.rpartition(".")
    populate_sql = populate_template.render(
        {
            "source_dataset": metered_dataset,
            "target_dataset": target_dataset,
            "project_info_table": self.config.project_info_table,
            "metered_load_shape_table": metered_table,
            "metered_load_shape_table_only_name": metered_table_name,
        }
    )
    logging.info("metered_load_shape sql = %s", populate_sql)
    self.client.query(populate_sql).result()
def process_therms_profile(self, therms_profiles_path: str, truncate: bool = False):
    """Populate ``{target_dataset}.therms_profile`` from ``config.therms_profiles_table``.

    The target table is prepared with ``bq_create_therms_profile.sql`` and
    then filled by rendering and running ``bq_populate_therms_profile.sql``.

    Args:
        therms_profiles_path: Not used by this implementation; the source is
            the table named by ``config.therms_profiles_table``.
        truncate: Forwarded to ``_prepare_table``.
    """
    target_dataset = self._get_target_dataset()
    self._prepare_table(
        f"{target_dataset}.therms_profile",
        "bq_create_therms_profile.sql",
        truncate=truncate,
    )
    populate_template = self.template_env.get_template(
        "bq_populate_therms_profile.sql"
    )
    profiles_table = self.config.therms_profiles_table
    # "project.dataset.table" -> ("project.dataset", ".", "table")
    profiles_dataset, _, profiles_table_name = profiles_table.rpartition(".")
    populate_sql = populate_template.render(
        {
            "source_dataset": profiles_dataset,
            "target_dataset": target_dataset,
            "therms_profiles_table": profiles_table,
            "therms_profiles_table_only_name": profiles_table_name,
        }
    )
    logging.debug("therms_profile sql = %s", populate_sql)
    self.client.query(populate_sql).result()
def _elec_load_shape_for_context(self):
    """Return the elec load shape table the calculation SQL should read.

    The processed ``{target_dataset}.elec_load_shape`` is used when either
    metered or elec load shape processing is enabled; otherwise the
    configured ``config.elec_load_shape_table`` is returned unchanged.
    """
    cfg = self.config
    processed = cfg.process_metered_load_shape or cfg.process_elec_load_shape
    if not processed:
        return cfg.elec_load_shape_table
    return f"{self._get_target_dataset()}.elec_load_shape"
def _therms_profile_for_context(self):
    """Return the therms profile table the calculation SQL should read.

    ``{target_dataset}.therms_profile`` when therms profile processing is
    enabled, else the configured ``config.therms_profiles_table``.
    """
    if not self.config.process_therms_profiles:
        return self.config.therms_profiles_table
    return f"{self._get_target_dataset()}.therms_profile"
def _get_calculation_sql_context(self, mode=""):
    """Build the Jinja context used to render the BigQuery calculation SQL.

    Args:
        mode: ``"electric"`` or ``"gas"`` to include only that fuel's
            aggregation columns and additional fields; any other value
            (default ``""``) includes both, with the gas additional fields
            restricted to those not already in the electric ones.

    Returns:
        dict of template variables. ``create_clause`` is included only when at
        least one of ``config.output_table``, ``config.electric_output_table``
        or ``config.gas_output_table`` is set; it names the table for ``mode``
        (``output_table`` for the combined mode).
    """
    elec_agg_columns = self._elec_aggregation_columns()
    gas_agg_columns = self._gas_aggregation_columns()
    elec_addl_fields = self._elec_addl_fields(elec_agg_columns)
    gas_addl_fields = self._gas_addl_fields(gas_agg_columns)
    # TODO double-check this: should the av_costs tables be treated the same as the load shapes?
    context = {
        "project_info_table": self.config.project_info_table,
        "eac_table": self.config.elec_av_costs_table,
        "els_table": self._elec_load_shape_for_context(),
        "gac_table": self.config.gas_av_costs_table,
        "therms_profile_table": self._therms_profile_for_context(),
        "float_type": self.config.float_type(),
        "database_type": self.config.database_type,
        "elec_components": self._elec_components(),
        "gas_components": self._gas_components(),
        "use_value_curve_name_for_join": self.config.use_value_curve_name_for_join,
    }
    if mode == "electric":
        context["elec_aggregation_columns"] = elec_agg_columns
        context["elec_addl_fields"] = elec_addl_fields
    elif mode == "gas":
        context["gas_aggregation_columns"] = gas_agg_columns
        context["gas_addl_fields"] = gas_addl_fields
    else:
        context["elec_aggregation_columns"] = elec_agg_columns
        context["gas_aggregation_columns"] = gas_agg_columns
        context["elec_addl_fields"] = elec_addl_fields
        # Gas-only additional fields, de-duplicated but kept in their original
        # order. A set difference would iterate str values in an order that
        # changes between interpreter runs (hash randomization), making the
        # rendered column order -- and any header-less output -- unstable.
        elec_field_set = set(elec_addl_fields)
        context["gas_addl_fields"] = list(
            dict.fromkeys(
                field for field in gas_addl_fields if field not in elec_field_set
            )
        )
    config = self.config
    if config.output_table or config.electric_output_table or config.gas_output_table:
        table_name = config.output_table
        if mode == "electric":
            table_name = config.electric_output_table
        elif mode == "gas":
            table_name = config.gas_output_table
        # NOTE(review): if the table for this mode is unset while another
        # output table is set, table_name is None and this renders
        # "TABLE None"; confirm callers never configure it that way.
        context["create_clause"] = f"CREATE OR REPLACE TABLE {table_name} AS ("
    return context
def _run_calc(self, sql):
    """Execute the calculation query and emit its rows if no output table is configured.

    When any of ``config.output_table``, ``config.electric_output_table`` or
    ``config.gas_output_table`` is set, nothing is emitted here (presumably
    the SQL itself writes the results through its CREATE OR REPLACE TABLE
    clause). Otherwise each result row is written as one CSV line, with no
    header, to ``config.output_file`` if set or to stdout if not.

    Each field is formatted with ``f"{value}"`` (so ``None`` is written as
    ``None``). Fields containing a comma, double quote or newline are quoted
    per CSV rules, so such values no longer corrupt the row.
    """
    result = self.client.query(sql).result()
    config = self.config
    if config.output_table or config.electric_output_table or config.gas_output_table:
        return

    def write_rows(stream):
        # csv.writer escapes delimiter/quote/newline characters that a plain
        # ",".join would pass through unescaped.
        writer = csv.writer(stream, lineterminator="\n")
        for row in result:
            writer.writerow([f"{value}" for value in row.values()])

    if config.output_file:
        # newline="" as the csv module docs recommend, so the writer alone
        # controls line endings.
        with open(config.output_file, "w", newline="") as outfile:
            write_rows(outfile)
    else:
        write_rows(sys.stdout)
def process_project_info(self, project_info_path: str):
    """Do nothing: this implementation loads no project info.

    The calculation context refers to ``config.project_info_table``
    directly, so ``project_info_path`` is unused.
    """
def reset_elec_av_costs(self):
    """Deliberately do nothing.

    The elec avoided costs table is never modified here, and the base
    class's ``reset_elec_av_costs`` would truncate it, so this override is a
    no-op.
    """
def reset_gas_av_costs(self):
    """Drop the ``datetime`` column that FLEXvalue adds to the gas avoided-costs table.

    A ``BadRequest`` from the ALTER TABLE is tolerated, because it occurs
    when resetting before the column was ever added. The ignored error is
    logged at debug level rather than discarded silently.
    """
    table_name = self.config.gas_av_costs_table
    sql = f"ALTER TABLE {table_name} DROP COLUMN datetime;"
    query_job = self.client.query(sql)
    try:
        query_job.result()
    except api_core.exceptions.BadRequest as e:
        # Expected when the datetime column doesn't exist yet.
        # NOTE(review): this also hides any other BadRequest from the ALTER.
        logging.debug(
            "Ignoring error dropping datetime column from %s: %s", table_name, e
        )
def reset_elec_load_shape(self):
    """Reset ``{target_dataset}.elec_load_shape`` via ``_reset_table``."""
    logging.debug("Resetting elec load shape")
    target = f"{self._get_target_dataset()}.elec_load_shape"
    self._reset_table(target)
def reset_therms_profiles(self):
    """Reset ``{target_dataset}.therms_profile`` via ``_reset_table``."""
    logging.debug("Resetting therms_profile")
    target = f"{self._get_target_dataset()}.therms_profile"
    self._reset_table(target)
def _reset_table(self, table_name):
    """Run ``<truncate prefix> <table_name> WHERE TRUE;`` against ``table_name``.

    ``NotFound`` is ignored: a table that doesn't exist yet will be created
    later.
    """
    statement = f"{self._get_truncate_prefix()} {table_name} WHERE TRUE;"
    try:
        self.client.query(statement).result()
    except NotFound:
        # Missing table; a later step creates it.
        pass
def _exec_select_sql(self, sql: str):
    """Run ``sql`` and return every result row as a list (exists to support testing)."""
    return list(self.client.query(sql).result())