import pandas as pd
import psycopg2
import concurrent.futures

import time
import psycopg2
from psycopg2 import Error
import random
import matplotlib.pyplot as plt

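# To set up Postgres:
#  1. Install the PostgreSQL package.
#  2. On the command line, initialize a data directory, start the server,
#     create the database, and open a shell on it (in this order):
#       > initdb test
#       > pg_ctl -D test start -l logfile
#       > createdb test
#       > psql test
#  3. In the psql shell, create and index the benchmark table:
#       > select i as a, i as b, i as c into t from generate_series(1,1000) as i;
#       > create index i on t(a);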

def run_xactions(isolation_level, n_xactions=1000, *, host="127.0.0.1",
                 port="5432", database="test", user="", password=""):
    """Commit `n_xactions` read-then-update transactions at one isolation level.

    Each transaction runs ``select count(*) from t`` and then increments ``b``
    on one randomly chosen row of ``t``. A transaction aborted with a
    transaction-rollback error (e.g. serialization failure or deadlock) is
    rolled back and retried until it commits. The total number of attempts is
    printed as ``ntries = <count>``.

    Args:
        isolation_level: a ``psycopg2.extensions.ISOLATION_LEVEL_*`` constant.
        n_xactions: number of transactions that must commit (default 1000).
        host, port, database, user, password: connection parameters; the
            defaults match the local setup described at the top of the file.

    Returns:
        1 (kept for compatibility with existing callers).

    Raises:
        psycopg2.Error: any database error that is not a transaction-rollback
            error. Such errors are not transient, so retrying them would loop
            forever.
    """
    connection = psycopg2.connect(user=user,
                                  password=password,
                                  host=host,
                                  port=port,
                                  database=database)
    try:
        # Set before the first statement, while no transaction is open.
        connection.isolation_level = isolation_level
        cursor = connection.cursor()
        ntries = 0
        for _ in range(n_xactions):
            # Retry this transaction until it commits.
            while True:
                ntries += 1
                try:
                    cursor.execute("select count(*) from t;")
                    cursor.fetchone()
                    # Column `a` holds 1..1000 (see setup), so pick from that range.
                    cursor.execute("update t set b = b + 1 where a = %s",
                                   (random.randint(1, 1000),))
                    connection.commit()
                    break
                except psycopg2.extensions.TransactionRollbackError:
                    # The transaction is in an aborted state; reset it and retry.
                    connection.rollback()
        print(f"ntries = {ntries}")
    finally:
        connection.close()
    return 1


def f():
    """Return the integer constant 1 (not referenced elsewhere in this file)."""
    result = 1
    return result

def _time_isolation_level(nprocs, isolation_level):
    """Run `nprocs` concurrent run_xactions workers and return elapsed seconds.

    Each future's result is fetched, so an exception raised in a worker
    propagates here instead of being silently discarded. Leaving the
    executor's ``with`` block waits for all workers, so the measured time
    covers every worker, including its connection setup.
    """
    start = time.perf_counter()
    with concurrent.futures.ThreadPoolExecutor(max_workers=nprocs) as executor:
        futures = [executor.submit(run_xactions, isolation_level)
                   for _ in range(nprocs)]
        for future in concurrent.futures.as_completed(futures):
            future.result()
    return time.perf_counter() - start


if __name__ == '__main__':
    results = {}
    for nprocs in [1, 2, 4, 6, 8, 12, 16, 20, 30]:
        # Note: PostgreSQL runs READ UNCOMMITTED with READ COMMITTED semantics.
        read_uncommitted = _time_isolation_level(
            nprocs, psycopg2.extensions.ISOLATION_LEVEL_READ_UNCOMMITTED)
        serializable = _time_isolation_level(
            nprocs, psycopg2.extensions.ISOLATION_LEVEL_SERIALIZABLE)
        print(f'n = {nprocs}, READ_UNCOMMITTED took {read_uncommitted}, '
              f'SERIALIZABLE took {serializable}')
        results[nprocs] = [read_uncommitted, serializable]

    # Plot inside the guard: `results` only exists when run as a script.
    # Convert dict views to lists before handing them to matplotlib.
    thread_counts = list(results)
    plt.plot(thread_counts, [times[0] for times in results.values()],
             label='Read Uncommitted')
    plt.plot(thread_counts, [times[1] for times in results.values()],
             label='Serializable')
    plt.legend()
    plt.xlabel("Number of Concurrent Threads")
    plt.ylabel("Execution Time (s)")
    plt.show()