Thief of Wealth

numGraphs = 3 # Maximum number of subplots per row when several variables are plotted

subWidth = 12 # Figure width (inches) shared by each row of subplots

subHeight = 6 # Figure height (inches) allotted per row of subplots

font = 8 # Font size intended for subplots (not referenced in the visible code)


def plotNum(setType,dataset,fields):
    """Draw a KDE (kernel density estimate) plot for each requested field.

    A single field is drawn on the current matplotlib axes. Several fields
    are laid out on a grid with at most ``numGraphs`` plots per row; grid
    cells left over in the last row are hidden instead of being rendered
    as empty axes. The figure is displayed with ``plt.show()``.

    Args:
        setType: Label used in each plot title (e.g. "train" or "test").
        dataset: Object indexable by field name whose items support
            ``.dropna()`` -- presumably a pandas DataFrame; confirm
            against callers.
        fields: Sequence of field names to plot.

    Returns:
        None. When ``fields`` is empty nothing is drawn or shown.
    """

    sns.set(style='darkgrid')

    n = len(fields) # Number of variables to plot

    if n == 0:
        # plt.subplots(1, 0) would raise; there is simply nothing to plot.
        return None

    if n == 1:
        field = fields[0]
        data = dataset[field].dropna()
        plt.xticks(rotation=90)
        sns.kdeplot(data).set_title("%s set %s" % (setType, field))
    else:
        cols = min(n, numGraphs)
        rows = math.ceil(n / cols)
        size = (subWidth, subHeight * rows) # subHeight inches of height per row

        # squeeze=False always yields a 2-D axes array, so the same indexing
        # works for one row, one column (numGraphs == 1), or a full grid.
        _, axes = plt.subplots(rows, cols, figsize=size, squeeze=False)
        flatAxes = axes.flatten() # Row-major: plot i sits at [i // cols, i % cols]

        for ax, field in zip(flatAxes, fields):
            data = dataset[field].dropna()
            sns.kdeplot(data, ax=ax).set_title("%s set %s" % (setType, field))
            ax.tick_params(axis='x', labelrotation=90)
            ax.ticklabel_format(style='sci', axis='y', scilimits=(0,0))

        # Hide the unused cells of the last row so no blank axes are shown.
        for ax in flatAxes[n:]:
            ax.set_visible(False)

    plt.show()

    return None



사용법
plotNum('train',train_combined,['TransactionAmt'])
plotNum('train' 또는 'test' 문자열, 데이터셋(DataFrame), 그릴 feature 이름들의 리스트)


profile on loading

Loading...