diff --git a/db.py b/db.py index 30e5228..3531aaa 100644 --- a/db.py +++ b/db.py @@ -1,9 +1,13 @@ # db.py import sqlite3 import pprint +import os def connect_db(as_object): - conn = sqlite3.connect('finance.db') + if 'ENV' in os.environ and os.environ['ENV'] == "production": + conn = sqlite3.connect('/finance.db') + else: + conn = sqlite3.connect('./finance.db') if as_object: conn.row_factory = sqlite3.Row # This allows us to access columns by name return conn