Through linear regression, we can predict a numeric value from its linear relationship with other variables. This can be very helpful in areas such as retail and e-commerce, where stores want to know the best selling price for their products.
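In its simplest form, a linear regression model predicts a target as a weighted sum of the input features plus a bias. For a row with features $x_1, \dots, x_n$, the model fits weights $w_j$ and a bias $b$. The loss written below is the mean absolute error over $m$ rows, which is what `F.l1_loss` computes in the model defined later in this notebook.

```latex
\hat{y} = \sum_{j=1}^{n} w_j x_j + b = \mathbf{w}^\top \mathbf{x} + b

\mathcal{L}(\mathbf{w}, b) = \frac{1}{m} \sum_{i=1}^{m} \left| \hat{y}^{(i)} - y^{(i)} \right|
```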

There are many good tools for implementing linear regression, such as PyTorch and TensorFlow. PyTorch is a Python machine learning package with two main features:

  1. Tensor computation (like NumPy) with strong GPU acceleration
  2. Automatic differentiation for building and training neural networks
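Both features fit in a few lines: tensors support NumPy-style arithmetic, and autograd records the operations so that calling `backward()` fills in gradients. A minimal sketch with made-up values (the names `x`, `w` and `b` are illustrative only):

```python
import torch

# Tensor computation: element-wise arithmetic and reductions, as in NumPy
x = torch.tensor([1.0, 2.0, 3.0])

# Parameters we want gradients for
w = torch.tensor(2.0, requires_grad=True)
b = torch.tensor(0.5, requires_grad=True)

# A tiny linear model and a scalar "loss": the sum of the predictions
y = w * x + b          # tensor([2.5, 4.5, 6.5])
loss = y.sum()         # 13.5

# Automatic differentiation: d(loss)/dw = sum(x) = 6, d(loss)/db = len(x) = 3
loss.backward()
print(w.grad, b.grad)  # tensor(6.) tensor(3.)
```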

In this project, we will use a vehicle dataset from cardekho, a company based in India that runs a car search platform to help customers buy vehicles that suit their needs. Besides the car name and the kilometers driven, which we will drop, the dataset contains the following columns (DataFrame column names in parentheses):

  1. year (Year)
  2. selling price (Selling_Price)
  3. showroom price (Present_Price)
  4. fuel type (Fuel_Type)
  5. seller type (Seller_Type)
  6. transmission (Transmission)
  7. number of previous owners (Owner)

The dataset for this problem is taken from: https://www.kaggle.com/nehalbirla/vehicle-dataset-from-cardekho

We will create a model with the following steps:

  1. Download and explore the dataset
  2. Prepare the dataset for training
  3. Create a linear regression model
  4. Train the model to fit the data
  5. Make predictions using the trained model

In [1]:

# Import libraries
import torch
import torchvision
import torch.nn as nn
import pandas as pd
import matplotlib.pyplot as plt
import torch.nn.functional as F
import seaborn as sns
from torchvision.datasets.utils import download_url
from torch.utils.data import DataLoader, TensorDataset, random_split

Step 1: Download and explore the data

Let us begin by loading the data. The CSV file downloaded from Kaggle is uploaded to the Colab session and read into a Pandas DataFrame.

In [5]:

# Import .csv file
from google.colab import files
uploaded = files.upload()


Saving datasets_33080_43333_car_data.csv to datasets_33080_43333_car_data.csv

In [6]:

# Create Pandas Dataframe
import io
dataframe = pd.read_csv(io.BytesIO(uploaded['datasets_33080_43333_car_data.csv']))
dataframe = dataframe.drop(['Car_Name', 'Kms_Driven'], axis=1) # Exclude columns that we won't use
dataframe

Out[6]:

     Year  Selling_Price  Present_Price Fuel_Type Seller_Type Transmission  Owner
0    2014           3.35           5.59    Petrol      Dealer       Manual      0
1    2013           4.75           9.54    Diesel      Dealer       Manual      0
2    2017           7.25           9.85    Petrol      Dealer       Manual      0
3    2011           2.85           4.15    Petrol      Dealer       Manual      0
4    2014           4.60           6.87    Diesel      Dealer       Manual      0
..    ...            ...            ...       ...         ...          ...    ...
296  2016           9.50          11.60    Diesel      Dealer       Manual      0
297  2015           4.00           5.90    Petrol      Dealer       Manual      0
298  2009           3.35          11.00    Petrol      Dealer       Manual      0
299  2017          11.50          12.50    Diesel      Dealer       Manual      0
300  2016           5.30           5.90    Petrol      Dealer       Manual      0

301 rows × 7 columns
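`google.colab.files.upload()` only works inside a Colab session. Outside Colab, a minimal alternative is to pass a local path to `pd.read_csv`, assuming the Kaggle file has been downloaded first; the file name in the comment below is a hypothetical placeholder, and `load_car_data` is a helper introduced here for illustration. To stay self-contained, the sketch reads an in-memory CSV whose header uses the column names from this notebook; its `Car_Name` and `Kms_Driven` values are made-up placeholders.

```python
import io
import pandas as pd

def load_car_data(source):
    """Read the car CSV and drop the columns this notebook does not use.

    `source` can be a file path or any file-like object accepted by pd.read_csv.
    """
    df = pd.read_csv(source)
    return df.drop(['Car_Name', 'Kms_Driven'], axis=1)

# Toy stand-in for the real file (Car_Name and Kms_Driven values are placeholders)
sample_csv = io.StringIO(
    "Car_Name,Year,Selling_Price,Present_Price,Kms_Driven,Fuel_Type,Seller_Type,Transmission,Owner\n"
    "car_a,2014,3.35,5.59,10000,Petrol,Dealer,Manual,0\n"
    "car_b,2013,4.75,9.54,20000,Diesel,Dealer,Manual,0\n"
)
df = load_car_data(sample_csv)
print(list(df.columns))

# With a local download instead (hypothetical file name):
# dataframe = load_car_data('car_data.csv')
```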

In [0]:

num_rows = len(dataframe)
input_cols = ['Year', 'Selling_Price', 'Fuel_Type', 'Seller_Type', 'Transmission', 'Owner']
categorical_cols = ['Year', 'Fuel_Type', 'Seller_Type', 'Transmission', 'Owner']
output_cols = ['Present_Price']

In [8]:

print('Minimum Price: {}'.format(dataframe.Present_Price.min()))
print('Average Price: {}'.format(dataframe.Present_Price.mean()))
print('Maximum Price: {}'.format(dataframe.Present_Price.max()))
Minimum Price: 0.32
Average Price: 7.628471760797344
Maximum Price: 92.6
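The three `print` calls can also be collapsed into a single `Series.agg` call. Adding the median is a cheap way to check whether a few expensive cars pull the mean upward, which the large gap between the average (about 7.63) and the maximum (92.6) hints at. The sketch uses the first five `Present_Price` values from the table above as toy data; on the real data the call would be `dataframe.Present_Price.agg([...])`.

```python
import pandas as pd

# First five Present_Price values from the DataFrame shown above (toy subset)
present_price = pd.Series([5.59, 9.54, 9.85, 4.15, 6.87], name='Present_Price')

# One call instead of separate min()/mean()/max() prints; median added for comparison
stats = present_price.agg(['min', 'mean', 'median', 'max'])
print(stats)
```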

In [9]:

plt.title("Distribution of Showroom Price")
sns.histplot(dataframe.Present_Price);  # distplot is deprecated in newer seaborn versions

Step 2: Prepare the dataset for training

In [0]:

def dataframe_to_arrays(dataframe):
    # Make a copy of the original dataframe
    dataframe1 = dataframe.copy(deep=True)
    # Replace each categorical column with integer category codes
    for col in categorical_cols:
        dataframe1[col] = dataframe1[col].astype('category').cat.codes
    # Extract inputs & outputs as numpy arrays
    inputs_array = dataframe1[input_cols].to_numpy()
    targets_array = dataframe1[output_cols].to_numpy()
    return inputs_array, targets_array
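`astype('category').cat.codes` replaces each distinct value with its integer position in the sorted list of categories. For string columns the categories are sorted alphabetically, so with the toy values below `CNG` becomes 0, `Diesel` 1 and `Petrol` 2. The same ordering is consistent with Diesel rows showing a fuel code of 1 and Petrol rows a code of 2 in the arrays printed below.

```python
import pandas as pd

fuel = pd.Series(['Petrol', 'Diesel', 'CNG', 'Petrol'])  # toy values
as_cat = fuel.astype('category')

# Categories are sorted; each value is replaced by its position in that list
print(list(as_cat.cat.categories))  # ['CNG', 'Diesel', 'Petrol']
print(as_cat.cat.codes.tolist())    # [2, 1, 0, 2]

# Keep the mapping so that codes can be translated back later
code_to_label = dict(enumerate(as_cat.cat.categories))
print(code_to_label)                # {0: 'CNG', 1: 'Diesel', 2: 'Petrol'}
```

Note that integer codes impose an ordering (CNG < Diesel < Petrol) that a linear model treats as numeric; that ordering comes from the encoding, not from the fuel types themselves.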

In [11]:

inputs_array, targets_array = dataframe_to_arrays(dataframe) # Create arrays for inputs and targets
inputs_array, targets_array

Out[11]:

(array([[11.  ,  3.35,  2.  ,  0.  ,  1.  ,  0.  ],
        [10.  ,  4.75,  1.  ,  0.  ,  1.  ,  0.  ],
        [14.  ,  7.25,  2.  ,  0.  ,  1.  ,  0.  ],
        ...,
        [ 6.  ,  3.35,  2.  ,  0.  ,  1.  ,  0.  ],
        [14.  , 11.5 ,  1.  ,  0.  ,  1.  ,  0.  ],
        [13.  ,  5.3 ,  2.  ,  0.  ,  1.  ,  0.  ]]), array([[ 5.59 ],
        [ 9.54 ],
        [ 9.85 ],
        [ 4.15 ],
        [ 6.87 ],
        [ 9.83 ],
        [ 8.12 ],
        [ 8.61 ],
        [ 8.89 ],
        [ 8.92 ],
        [ 3.6  ],
        [10.38 ],
        [ 9.94 ],
        [ 7.71 ],
        [ 7.21 ],
        [10.79 ],
        [10.79 ],
        [10.79 ],
        [ 5.09 ],
        [ 7.98 ],
        [ 3.95 ],
        [ 5.71 ],
        [ 8.01 ],
        [ 3.46 ],
        [ 4.41 ],
        [ 4.99 ],
        [ 5.87 ],
        [ 6.49 ],
        [ 3.95 ],
        [10.38 ],
        [ 5.98 ],
        [ 4.89 ],
        [ 7.49 ],
        [ 9.95 ],
        [ 8.06 ],
        [ 7.74 ],
        [ 7.2  ],
        [ 2.28 ],
        [ 3.76 ],
        [ 7.98 ],
        [ 7.87 ],
        [ 3.98 ],
        [ 7.15 ],
        [ 8.06 ],
        [ 2.69 ],
        [12.04 ],
        [ 4.89 ],
        [ 4.15 ],
        [ 7.71 ],
        [ 9.29 ],
        [30.61 ],
        [30.61 ],
        [19.77 ],
        [30.61 ],
        [10.21 ],
        [15.04 ],
        [ 7.27 ],
        [18.54 ],
        [ 6.8  ],
        [35.96 ],
        [18.61 ],
        [ 7.7  ],
        [35.96 ],
        [35.96 ],
        [36.23 ],
        [ 6.95 ],
        [23.15 ],
        [20.45 ],
        [13.74 ],
        [20.91 ],
        [ 6.76 ],
        [12.48 ],
        [18.61 ],
        [ 5.71 ],
        [ 8.93 ],
        [ 6.8  ],
        [14.68 ],
        [12.35 ],
        [22.83 ],
        [30.61 ],
        [14.89 ],
        [ 7.85 ],
        [25.39 ],
        [13.46 ],
        [13.46 ],
        [23.73 ],
        [92.6  ],
        [13.74 ],
        [ 6.05 ],
        [ 6.76 ],
        [18.61 ],
        [16.09 ],
        [13.7  ],
        [30.61 ],
        [22.78 ],
        [18.61 ],
        [25.39 ],
        [18.64 ],
        [18.61 ],
        [20.45 ],
        [ 1.9  ],
        [ 1.82 ],
        [ 1.78 ],
        [ 1.6  ],
        [ 1.47 ],
        [ 2.37 ],
        [ 3.45 ],
        [ 1.5  ],
        [ 1.5  ],
        [ 1.47 ],
        [ 1.78 ],
        [ 1.5  ],
        [ 2.4  ],
        [ 1.4  ],
        [ 1.47 ],
        [ 1.47 ],
        [ 1.47 ],
        [ 1.9  ],
        [ 1.47 ],
        [ 1.9  ],
        [ 1.26 ],
        [ 1.5  ],
        [ 1.17 ],
        [ 1.47 ],
        [ 1.75 ],
        [ 1.75 ],
        [ 0.95 ],
        [ 0.8  ],
        [ 0.87 ],
        [ 0.84 ],
        [ 0.87 ],
        [ 0.82 ],
        [ 0.95 ],
        [ 0.95 ],
        [ 0.81 ],
        [ 0.74 ],
        [ 1.2  ],
        [ 0.787],
        [ 0.87 ],
        [ 0.95 ],
        [ 1.2  ],
        [ 0.8  ],
        [ 0.84 ],
        [ 0.84 ],
        [ 0.99 ],
        [ 0.81 ],
        [ 0.787],
        [ 0.84 ],
        [ 0.94 ],
        [ 0.94 ],
        [ 0.826],
        [ 0.55 ],
        [ 0.99 ],
        [ 0.99 ],
        [ 0.88 ],
        [ 0.51 ],
        [ 0.52 ],
        [ 0.84 ],
        [ 0.54 ],
        [ 0.51 ],
        [ 0.95 ],
        [ 0.826],
        [ 0.99 ],
        [ 0.95 ],
        [ 0.54 ],
        [ 0.54 ],
        [ 0.55 ],
        [ 0.81 ],
        [ 0.73 ],
        [ 0.54 ],
        [ 0.83 ],
        [ 0.55 ],
        [ 0.64 ],
        [ 0.51 ],
        [ 0.72 ],
        [ 0.787],
        [ 1.05 ],
        [ 0.57 ],
        [ 0.52 ],
        [ 1.05 ],
        [ 0.51 ],
        [ 0.48 ],
        [ 0.58 ],
        [ 0.47 ],
        [ 0.75 ],
        [ 0.58 ],
        [ 0.52 ],
        [ 0.51 ],
        [ 0.57 ],
        [ 0.57 ],
        [ 0.75 ],
        [ 0.57 ],
        [ 0.75 ],
        [ 0.65 ],
        [ 0.787],
        [ 0.32 ],
        [ 0.52 ],
        [ 0.51 ],
        [ 0.57 ],
        [ 0.58 ],
        [ 0.75 ],
        [ 6.79 ],
        [ 5.7  ],
        [ 4.6  ],
        [ 4.43 ],
        [ 5.7  ],
        [ 7.13 ],
        [ 5.7  ],
        [ 8.1  ],
        [ 5.7  ],
        [ 4.6  ],
        [14.79 ],
        [13.6  ],
        [ 6.79 ],
        [ 5.7  ],
        [ 9.4  ],
        [ 4.43 ],
        [ 4.43 ],
        [ 9.4  ],
        [ 9.4  ],
        [ 4.43 ],
        [ 6.79 ],
        [ 7.6  ],
        [ 9.4  ],
        [ 9.4  ],
        [ 4.6  ],
        [ 5.7  ],
        [ 4.43 ],
        [ 9.4  ],
        [ 6.79 ],
        [ 9.4  ],
        [ 9.4  ],
        [14.79 ],
        [ 5.7  ],
        [ 5.7  ],
        [ 9.4  ],
        [ 4.43 ],
        [13.6  ],
        [ 9.4  ],
        [ 4.43 ],
        [ 9.4  ],
        [ 7.13 ],
        [ 7.13 ],
        [ 7.6  ],
        [ 9.4  ],
        [ 9.4  ],
        [ 6.79 ],
        [ 9.4  ],
        [ 4.6  ],
        [ 7.6  ],
        [13.6  ],
        [ 9.9  ],
        [ 6.82 ],
        [ 9.9  ],
        [ 9.9  ],
        [ 5.35 ],
        [13.6  ],
        [13.6  ],
        [13.6  ],
        [ 7.   ],
        [13.6  ],
        [ 5.97 ],
        [ 5.8  ],
        [ 7.7  ],
        [ 7.   ],
        [ 8.7  ],
        [ 7.   ],
        [ 9.4  ],
        [ 5.8  ],
        [10.   ],
        [10.   ],
        [10.   ],
        [10.   ],
        [ 7.5  ],
        [ 6.8  ],
        [13.6  ],
        [13.6  ],
        [13.6  ],
        [ 8.4  ],
        [13.6  ],
        [ 5.9  ],
        [ 7.6  ],
        [14.   ],
        [11.8  ],
        [ 5.9  ],
        [ 8.5  ],
        [ 7.9  ],
        [ 7.5  ],
        [13.6  ],
        [13.6  ],
        [ 6.4  ],
        [ 6.1  ],
        [ 8.4  ],
        [ 9.9  ],
        [ 6.8  ],
        [13.09 ],
        [11.6  ],
        [ 5.9  ],
        [11.   ],
        [12.5  ],
        [ 5.9  ]]))

In [0]:

# Create PyTorch Tensors from Numpy arrays
inputs = torch.from_numpy(inputs_array).float()
targets = torch.from_numpy(targets_array).float()
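`torch.from_numpy` shares memory with the NumPy array and keeps its dtype, which for the arrays above is `float64`. The parameters of `nn.Linear` are `float32` by default, hence the `.float()` call, which also produces a separate copy. A small check on a toy array:

```python
import numpy as np
import torch

arr = np.array([[1.0, 2.0], [3.0, 4.0]])  # float64, like the arrays above

shared = torch.from_numpy(arr)  # same memory, dtype torch.float64
converted = shared.float()      # new float32 tensor (a copy)

arr[0, 0] = 100.0               # modify the NumPy array in place
print(shared.dtype, shared[0, 0].item())        # torch.float64 100.0
print(converted.dtype, converted[0, 0].item())  # torch.float32 1.0
```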

Next, we need to create PyTorch datasets & data loaders for training & validation. We’ll start by creating a TensorDataset.

In [0]:

dataset = TensorDataset(inputs, targets)
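A `TensorDataset` wraps tensors that share the same first dimension; indexing it returns one `(input, target)` tuple per row, which is the format that `random_split` and `DataLoader` work with. A toy example with three rows of six features, matching the shapes of `inputs` and `targets` here:

```python
import torch
from torch.utils.data import TensorDataset

toy_inputs = torch.arange(18, dtype=torch.float32).reshape(3, 6)  # 3 rows x 6 features
toy_targets = torch.tensor([[5.59], [9.54], [9.85]])              # 3 rows x 1 target

toy_ds = TensorDataset(toy_inputs, toy_targets)
x0, y0 = toy_ds[0]                       # first row as an (input, target) tuple
print(len(toy_ds), x0.shape, y0.shape)   # 3 torch.Size([6]) torch.Size([1])
```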

In [0]:

val_percent = 0.1
val_size = int(num_rows * val_percent)
train_size = num_rows - val_size

train_ds, val_ds = random_split(dataset, [train_size, val_size]) # split dataset into 2 parts of the desired length
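`random_split` draws a new random permutation on every run, so the training/validation split, and therefore the validation losses reported below, changes each time the notebook is rerun. Passing a seeded `torch.Generator` makes the split reproducible. A sketch on a random stand-in dataset of the same size (301 rows); the seed 42 is arbitrary and `seeded_split` is a hypothetical helper:

```python
import torch
from torch.utils.data import TensorDataset, random_split

num_rows = 301
toy_ds = TensorDataset(torch.randn(num_rows, 6), torch.randn(num_rows, 1))

val_size = int(num_rows * 0.1)    # int(30.1) == 30
train_size = num_rows - val_size  # 271

def seeded_split(seed):
    # A fresh generator with a fixed seed yields the same permutation every time
    return random_split(toy_ds, [train_size, val_size],
                        generator=torch.Generator().manual_seed(seed))

train_a, val_a = seeded_split(42)
train_b, val_b = seeded_split(42)
print(len(train_a), len(val_a))        # 271 30
print(val_a.indices == val_b.indices)  # True
```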

Finally, we can create data loaders for training & validation.

In [0]:

batch_size = 50
train_loader = DataLoader(train_ds, batch_size, shuffle=True)
val_loader = DataLoader(val_ds, batch_size)
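With 271 training rows and `batch_size = 50`, each epoch performs `ceil(271 / 50) = 6` optimizer steps: five full batches and a final batch of 21 rows (passing `drop_last=True` to `DataLoader` would discard it). The 30 validation rows fit in a single batch. A quick check on random stand-in datasets of the same sizes:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-ins with the same sizes as train_ds (271 rows) and val_ds (30 rows)
toy_train = TensorDataset(torch.randn(271, 6), torch.randn(271, 1))
toy_val = TensorDataset(torch.randn(30, 6), torch.randn(30, 1))

toy_train_loader = DataLoader(toy_train, 50, shuffle=True)
toy_val_loader = DataLoader(toy_val, 50)

batch_sizes = [xb.shape[0] for xb, _ in toy_train_loader]
print(len(toy_train_loader), batch_sizes)  # 6 [50, 50, 50, 50, 50, 21]
print(len(toy_val_loader))                 # 1
```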

In [16]:

for xb, yb in train_loader:
    print("inputs:", xb)
    print("targets:", yb)
    break
inputs: tensor([[ 5.0000,  0.2000,  2.0000,  1.0000,  1.0000,  0.0000],
        [12.0000,  6.8500,  1.0000,  0.0000,  1.0000,  0.0000],
        [ 3.0000,  1.0500,  2.0000,  0.0000,  1.0000,  0.0000],
        [ 6.0000,  0.9000,  2.0000,  1.0000,  1.0000,  0.0000],
        [ 9.0000,  4.9500,  1.0000,  0.0000,  1.0000,  0.0000],
        [13.0000,  1.0500,  2.0000,  1.0000,  1.0000,  0.0000],
        [ 6.0000,  3.3500,  2.0000,  0.0000,  1.0000,  0.0000],
        [11.0000,  2.5500,  2.0000,  0.0000,  1.0000,  0.0000],
        [12.0000,  0.4000,  2.0000,  1.0000,  1.0000,  0.0000],
        [11.0000,  2.5000,  2.0000,  0.0000,  1.0000,  0.0000],
        [13.0000, 20.7500,  1.0000,  0.0000,  0.0000,  0.0000],
        [ 8.0000,  3.0000,  2.0000,  0.0000,  1.0000,  0.0000],
        [ 5.0000,  4.0000,  2.0000,  0.0000,  0.0000,  0.0000],
        [13.0000,  0.7500,  2.0000,  1.0000,  1.0000,  0.0000],
        [13.0000,  6.6000,  2.0000,  0.0000,  1.0000,  0.0000],
        [11.0000,  3.7500,  2.0000,  0.0000,  1.0000,  0.0000],
        [11.0000,  3.3500,  2.0000,  0.0000,  1.0000,  0.0000],
        [ 9.0000, 14.5000,  1.0000,  0.0000,  0.0000,  0.0000],
        [13.0000,  6.0000,  2.0000,  0.0000,  1.0000,  0.0000],
        [10.0000,  1.2500,  2.0000,  1.0000,  1.0000,  0.0000],
        [13.0000,  0.4500,  2.0000,  1.0000,  0.0000,  0.0000],
        [12.0000,  4.6500,  2.0000,  0.0000,  1.0000,  0.0000],
        [12.0000, 23.5000,  1.0000,  0.0000,  0.0000,  0.0000],
        [ 2.0000,  2.7500,  2.0000,  1.0000,  1.0000,  0.0000],
        [ 3.0000,  2.1000,  2.0000,  0.0000,  1.0000,  0.0000],
        [13.0000, 14.7300,  1.0000,  0.0000,  1.0000,  0.0000],
        [11.0000,  3.6500,  2.0000,  0.0000,  1.0000,  0.0000],
        [10.0000,  0.2700,  2.0000,  1.0000,  1.0000,  0.0000],
        [10.0000,  0.6500,  2.0000,  1.0000,  1.0000,  0.0000],
        [11.0000,  8.2500,  1.0000,  0.0000,  1.0000,  0.0000],
        [13.0000,  0.6000,  2.0000,  1.0000,  1.0000,  0.0000],
        [ 9.0000,  2.0000,  2.0000,  0.0000,  1.0000,  0.0000],
        [11.0000,  3.9500,  1.0000,  0.0000,  1.0000,  0.0000],
        [ 7.0000, 35.0000,  1.0000,  0.0000,  1.0000,  0.0000],
        [13.0000,  6.4000,  2.0000,  0.0000,  1.0000,  0.0000],
        [10.0000,  2.6500,  2.0000,  0.0000,  1.0000,  0.0000],
        [10.0000,  6.9500,  2.0000,  0.0000,  1.0000,  0.0000],
        [10.0000,  0.4200,  2.0000,  1.0000,  1.0000,  0.0000],
        [10.0000,  5.1100,  2.0000,  0.0000,  0.0000,  0.0000],
        [ 5.0000,  0.2000,  2.0000,  1.0000,  1.0000,  0.0000],
        [14.0000, 23.0000,  1.0000,  0.0000,  0.0000,  0.0000],
        [11.0000,  5.3000,  2.0000,  0.0000,  1.0000,  0.0000],
        [12.0000,  1.1000,  2.0000,  1.0000,  1.0000,  0.0000],
        [ 8.0000,  0.3500,  2.0000,  1.0000,  1.0000,  0.0000],
        [ 7.0000,  2.6500,  2.0000,  0.0000,  1.0000,  0.0000],
        [11.0000,  0.3500,  2.0000,  1.0000,  0.0000,  0.0000],
        [14.0000,  0.4000,  2.0000,  1.0000,  0.0000,  0.0000],
        [13.0000, 11.2500,  2.0000,  0.0000,  1.0000,  0.0000],
        [ 6.0000,  2.2500,  2.0000,  0.0000,  1.0000,  0.0000],
        [ 3.0000,  2.5000,  2.0000,  1.0000,  0.0000,  2.0000]])
targets: tensor([[ 0.7500],
        [10.3800],
        [ 4.1500],
        [ 1.7500],
        [ 9.4000],
        [ 1.2600],
        [11.0000],
        [ 3.9800],
        [ 0.5500],
        [ 3.4600],
        [25.3900],
        [ 4.9900],
        [22.7800],
        [ 0.8000],
        [ 7.7000],
        [ 6.8000],
        [ 5.5900],
        [30.6100],
        [ 8.4000],
        [ 1.5000],
        [ 0.5400],
        [ 7.2000],
        [35.9600],
        [10.2100],
        [ 7.6000],
        [14.8900],
        [ 7.0000],
        [ 0.4700],
        [ 0.7870],
        [14.0000],
        [ 0.8700],
        [ 4.4300],
        [ 6.7600],
        [92.6000],
        [ 8.4000],
        [ 4.8900],
        [18.6100],
        [ 0.7300],
        [ 9.4000],
        [ 0.7870],
        [25.3900],
        [ 6.8000],
        [ 1.4700],
        [ 1.0500],
        [ 7.9800],
        [ 0.5200],
        [ 0.5100],
        [13.6000],
        [ 7.2100],
        [23.7300]])

Step 3: Create a Linear Regression Model

In [0]:

input_size = len(input_cols)
output_size = len(output_cols)

In [0]:

class PriceModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, xb):
        out = self.linear(xb)
        return out

    def training_step(self, batch):
        inputs, targets = batch 
        # Generate predictions
        out = self(inputs)          
        # Calculate loss
        loss = F.l1_loss(out, targets)
        return loss

    def validation_step(self, batch):
        inputs, targets = batch
        # Generate predictions
        out = self(inputs)
        # Calculate loss
        loss = F.l1_loss(out, targets)  
        return {'val_loss': loss.detach()}

    def validation_epoch_end(self, outputs):
        batch_losses = [x['val_loss'] for x in outputs]
        epoch_loss = torch.stack(batch_losses).mean()
        return {'val_loss': epoch_loss.item()}

    def epoch_end(self, epoch, result, num_epochs):
        # Print result every 20th epoch
        if (epoch+1) % 20 == 0 or epoch == num_epochs-1:
            print("Epoch [{}], val_loss: {:.4f}".format(epoch+1, result['val_loss']))
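Under the hood, `nn.Linear(input_size, output_size)` computes `xb @ W.T + b`, where `W` has shape `(output_size, input_size)`, and `F.l1_loss` (used in both `training_step` and `validation_step`) is the mean absolute error between predictions and targets. Both can be checked by hand on random toy tensors:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
linear = nn.Linear(6, 1)  # same shape as PriceModel.linear
xb = torch.randn(4, 6)    # a toy batch of 4 rows
yb = torch.randn(4, 1)    # toy targets

out = linear(xb)
manual_out = xb @ linear.weight.T + linear.bias  # (4, 6) @ (6, 1) + (1,)

loss = F.l1_loss(out, yb)
manual_loss = (out - yb).abs().mean()            # mean absolute error

print(torch.allclose(out, manual_out), torch.allclose(loss, manual_loss))  # True True
```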

In [0]:

model = PriceModel()

Let’s check out the weights and biases of the model using model.parameters().

In [20]:

list(model.parameters())

Out[20]:

[Parameter containing:
 tensor([[-0.3935, -0.3412, -0.0166, -0.0592,  0.0338, -0.1588]],
        requires_grad=True), Parameter containing:
 tensor([-0.3614], requires_grad=True)]
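The weight tensor has shape `(1, 6)` (one output, six input features) and the bias has shape `(1,)`. Their starting values are random: with its default initialization, `nn.Linear` draws both from a uniform distribution bounded by `1/sqrt(in_features)`, about 0.408 for six inputs, which is consistent with the values printed above. A quick check (the seed is arbitrary):

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(1)
linear = nn.Linear(6, 1)
bound = 1 / math.sqrt(linear.in_features)  # about 0.408 for 6 inputs

print(linear.weight.shape, linear.bias.shape)  # torch.Size([1, 6]) torch.Size([1])
# Small tolerance absorbs float rounding in how the bound is computed internally
print(bool((linear.weight.abs() <= bound + 1e-6).all()),
      bool((linear.bias.abs() <= bound + 1e-6).all()))  # True True
```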

Step 4: Train the model to fit the data

In [0]:

def evaluate(model, val_loader):
    outputs = [model.validation_step(batch) for batch in val_loader]
    return model.validation_epoch_end(outputs)

def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD):
    history = []
    optimizer = opt_func(model.parameters(), lr)
    for epoch in range(epochs):
        # Training Phase 
        for batch in train_loader:
            loss = model.training_step(batch)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        # Validation phase
        result = evaluate(model, val_loader)
        model.epoch_end(epoch, result, epochs)
        history.append(result)
    return history
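Each pass through the training loop in `fit` follows the usual pattern: `loss.backward()` stores gradients in each parameter's `.grad`, `optimizer.step()` applies the plain SGD update `w ← w - lr * w.grad`, and `optimizer.zero_grad()` clears the gradients, since PyTorch accumulates them across `backward()` calls. A minimal sketch on a single toy parameter:

```python
import torch

w = torch.tensor([1.0], requires_grad=True)
optimizer = torch.optim.SGD([w], lr=0.1)

# Loss = 3 * w, so d(loss)/dw = 3
(3 * w).sum().backward()
(3 * w).sum().backward()   # without zero_grad, the gradients add up
accumulated = w.grad.clone()
print(accumulated)         # tensor([6.])

optimizer.zero_grad()      # clear the accumulated gradient
(3 * w).sum().backward()   # gradient is now 3
optimizer.step()           # w = 1.0 - 0.1 * 3.0 = 0.7
print(w.detach())          # tensor([0.7000])
```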

In [22]:

result = evaluate(model, val_loader) # Use the evaluate function
print(result)
{'val_loss': 13.508749961853027}

In [23]:

epochs = 100
lr = 1e-2
history1 = fit(epochs, lr, model, train_loader, val_loader)
Epoch [20], val_loss: 1.6579
Epoch [40], val_loss: 1.4193
Epoch [60], val_loss: 1.4849
Epoch [80], val_loss: 1.3621
Epoch [100], val_loss: 1.4464

In [24]:

epochs = 100
lr = 1e-3
history2 = fit(epochs, lr, model, train_loader, val_loader)
Epoch [20], val_loss: 1.3564
Epoch [40], val_loss: 1.3603
Epoch [60], val_loss: 1.3568
Epoch [80], val_loss: 1.3567
Epoch [100], val_loss: 1.3553

In [25]:

epochs = 100
lr = 1e-4
history3 = fit(epochs, lr, model, train_loader, val_loader)
Epoch [20], val_loss: 1.3552
Epoch [40], val_loss: 1.3548
Epoch [60], val_loss: 1.3551
Epoch [80], val_loss: 1.3549
Epoch [100], val_loss: 1.3550

In [26]:

epochs = 100
lr = 1e-5
history4 = fit(epochs, lr, model, train_loader, val_loader)
Epoch [20], val_loss: 1.3550
Epoch [40], val_loss: 1.3550
Epoch [60], val_loss: 1.3549
Epoch [80], val_loss: 1.3550
Epoch [100], val_loss: 1.3550

In [27]:

epochs = 100
lr = 1e-6
history5 = fit(epochs, lr, model, train_loader, val_loader)
Epoch [20], val_loss: 1.3550
Epoch [40], val_loss: 1.3550
Epoch [60], val_loss: 1.3550
Epoch [80], val_loss: 1.3550
Epoch [100], val_loss: 1.3550
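The five runs above divide the learning rate by 10 every 100 epochs, creating a new optimizer each time. Because plain SGD without momentum keeps no internal state, the same schedule could alternatively be expressed with one optimizer and `torch.optim.lr_scheduler.StepLR`; this is an alternative, not what the notebook does. A sketch that only records the learning rate used in each epoch (the model is a stand-in and the training batches are omitted):

```python
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR

model = nn.Linear(6, 1)  # stand-in for PriceModel
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
scheduler = StepLR(optimizer, step_size=100, gamma=0.1)  # lr /= 10 every 100 epochs

lrs = []
for epoch in range(500):
    lrs.append(optimizer.param_groups[0]['lr'])
    # The training batches (loss.backward(), optimizer.step(), ...) would go here
    optimizer.step()   # no-op in this sketch: no gradients have been computed
    scheduler.step()   # advance the schedule once per epoch

print([f"{lr:.0e}" for lr in (lrs[0], lrs[100], lrs[200], lrs[499])])
# ['1e-02', '1e-03', '1e-04', '1e-06']
```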

In [28]:

val_loss = [result] + history1 + history2 + history3 + history4 + history5
print(val_loss)
val_loss_list = [vl['val_loss'] for vl in val_loss]

plt.plot(val_loss_list, '-x')

plt.xlabel('epochs')
plt.ylabel('losses')
[{'val_loss': 13.508749961853027}, {'val_loss': 6.118033409118652}, {'val_loss': 4.425381660461426}, {'val_loss': 3.5121304988861084}, {'val_loss': 3.141930103302002}, {'val_loss': 2.6487579345703125}, {'val_loss': 2.2823987007141113}, {'val_loss': 2.0026051998138428}, {'val_loss': 1.78737473487854}, {'val_loss': 1.600001573562622}, {'val_loss': 1.4982473850250244}, {'val_loss': 1.4682856798171997}, {'val_loss': 1.4618252515792847}, {'val_loss': 1.4482910633087158}, {'val_loss': 1.5700268745422363}, {'val_loss': 1.449874758720398}, {'val_loss': 1.4577511548995972}, {'val_loss': 1.4320111274719238}, {'val_loss': 1.4231066703796387}, {'val_loss': 1.4389463663101196}, {'val_loss': 1.65794038772583}, {'val_loss': 1.4212898015975952}, {'val_loss': 1.4200899600982666}, {'val_loss': 1.435904860496521}, {'val_loss': 1.5388320684432983}, {'val_loss': 1.4144498109817505}, {'val_loss': 1.4971991777420044}, {'val_loss': 1.5992761850357056}, {'val_loss': 1.4277279376983643}, {'val_loss': 1.5031235218048096}, {'val_loss': 1.4794392585754395}, {'val_loss': 1.4153716564178467}, {'val_loss': 1.4078534841537476}, {'val_loss': 1.510008454322815}, {'val_loss': 1.4809588193893433}, {'val_loss': 1.4999046325683594}, {'val_loss': 1.3946417570114136}, {'val_loss': 1.6055680513381958}, {'val_loss': 1.4423900842666626}, {'val_loss': 1.4460539817810059}, {'val_loss': 1.419303297996521}, {'val_loss': 1.3937510251998901}, {'val_loss': 1.4011147022247314}, {'val_loss': 1.4060081243515015}, {'val_loss': 1.3984336853027344}, {'val_loss': 1.4633814096450806}, {'val_loss': 1.3751095533370972}, {'val_loss': 1.4318023920059204}, {'val_loss': 1.6094945669174194}, {'val_loss': 1.3969550132751465}, {'val_loss': 1.7303458452224731}, {'val_loss': 1.3911798000335693}, {'val_loss': 1.3729671239852905}, {'val_loss': 1.54682457447052}, {'val_loss': 1.3787986040115356}, {'val_loss': 1.649728536605835}, {'val_loss': 1.3683972358703613}, {'val_loss': 1.367186188697815}, {'val_loss': 1.4891501665115356}, 
{'val_loss': 1.3684563636779785}, {'val_loss': 1.4849108457565308}, {'val_loss': 1.37216055393219}, {'val_loss': 1.3770359754562378}, {'val_loss': 1.3935377597808838}, {'val_loss': 1.3642178773880005}, {'val_loss': 1.4782458543777466}, {'val_loss': 1.4224978685379028}, {'val_loss': 1.4049519300460815}, {'val_loss': 1.4761677980422974}, {'val_loss': 1.5544482469558716}, {'val_loss': 1.3702653646469116}, {'val_loss': 1.4356273412704468}, {'val_loss': 1.400244951248169}, {'val_loss': 1.377777338027954}, {'val_loss': 1.448345422744751}, {'val_loss': 1.3617547750473022}, {'val_loss': 1.6235082149505615}, {'val_loss': 1.4901325702667236}, {'val_loss': 1.542846441268921}, {'val_loss': 1.360831379890442}, {'val_loss': 1.3620823621749878}, {'val_loss': 1.3667073249816895}, {'val_loss': 1.3608430624008179}, {'val_loss': 1.5464335680007935}, {'val_loss': 1.433320164680481}, {'val_loss': 1.3586939573287964}, {'val_loss': 1.3673779964447021}, {'val_loss': 1.3756794929504395}, {'val_loss': 1.4192850589752197}, {'val_loss': 1.372723937034607}, {'val_loss': 1.4133044481277466}, {'val_loss': 1.3854647874832153}, {'val_loss': 1.364600658416748}, {'val_loss': 1.3931787014007568}, {'val_loss': 1.4342844486236572}, {'val_loss': 1.3628737926483154}, {'val_loss': 1.5067977905273438}, {'val_loss': 1.3583155870437622}, {'val_loss': 1.356598138809204}, {'val_loss': 1.37332022190094}, {'val_loss': 1.4464164972305298}, {'val_loss': 1.3635132312774658}, {'val_loss': 1.3564374446868896}, {'val_loss': 1.3569012880325317}, {'val_loss': 1.3573507070541382}, {'val_loss': 1.3600316047668457}, {'val_loss': 1.3589403629302979}, {'val_loss': 1.3563119173049927}, {'val_loss': 1.3571209907531738}, {'val_loss': 1.356133222579956}, {'val_loss': 1.356939435005188}, {'val_loss': 1.359302282333374}, {'val_loss': 1.3567982912063599}, {'val_loss': 1.357097864151001}, {'val_loss': 1.3570212125778198}, {'val_loss': 1.3584320545196533}, {'val_loss': 1.3569546937942505}, {'val_loss': 1.357675313949585}, 
{'val_loss': 1.3573272228240967}, {'val_loss': 1.3567001819610596}, {'val_loss': 1.356406807899475}, {'val_loss': 1.3568202257156372}, {'val_loss': 1.360009789466858}, {'val_loss': 1.3571254014968872}, {'val_loss': 1.357373595237732}, {'val_loss': 1.3569358587265015}, {'val_loss': 1.3566490411758423}, {'val_loss': 1.358595609664917}, {'val_loss': 1.356493353843689}, {'val_loss': 1.3569475412368774}, {'val_loss': 1.3568549156188965}, {'val_loss': 1.3565493822097778}, {'val_loss': 1.3592419624328613}, {'val_loss': 1.3565335273742676}, {'val_loss': 1.3574392795562744}, {'val_loss': 1.3569074869155884}, {'val_loss': 1.3572746515274048}, {'val_loss': 1.3571385145187378}, {'val_loss': 1.357692003250122}, {'val_loss': 1.3599610328674316}, {'val_loss': 1.3603343963623047}, {'val_loss': 1.3565924167633057}, {'val_loss': 1.3568990230560303}, {'val_loss': 1.3564502000808716}, {'val_loss': 1.3568381071090698}, {'val_loss': 1.3563752174377441}, {'val_loss': 1.3565641641616821}, {'val_loss': 1.3578929901123047}, {'val_loss': 1.3564704656600952}, {'val_loss': 1.3579374551773071}, {'val_loss': 1.3567204475402832}, {'val_loss': 1.358005404472351}, {'val_loss': 1.3587855100631714}, {'val_loss': 1.3610790967941284}, {'val_loss': 1.3572267293930054}, {'val_loss': 1.3567177057266235}, {'val_loss': 1.3564910888671875}, {'val_loss': 1.3565599918365479}, {'val_loss': 1.3571439981460571}, {'val_loss': 1.3573533296585083}, {'val_loss': 1.3567684888839722}, {'val_loss': 1.3597580194473267}, {'val_loss': 1.3570773601531982}, {'val_loss': 1.3559664487838745}, {'val_loss': 1.3565402030944824}, {'val_loss': 1.356568694114685}, {'val_loss': 1.3579165935516357}, {'val_loss': 1.3558008670806885}, {'val_loss': 1.3559842109680176}, {'val_loss': 1.3560093641281128}, {'val_loss': 1.355743169784546}, {'val_loss': 1.3560271263122559}, {'val_loss': 1.3596975803375244}, {'val_loss': 1.3558939695358276}, {'val_loss': 1.3556714057922363}, {'val_loss': 1.3555657863616943}, {'val_loss': 1.3557069301605225}, 
{'val_loss': 1.3573758602142334}, {'val_loss': 1.3586246967315674}, {'val_loss': 1.3611829280853271}, {'val_loss': 1.3566774129867554}, {'val_loss': 1.3559454679489136}, {'val_loss': 1.357486367225647}, {'val_loss': 1.3554534912109375}, {'val_loss': 1.3553087711334229}, {'val_loss': 1.3563026189804077}, {'val_loss': 1.3549476861953735}, {'val_loss': 1.3553135395050049}, {'val_loss': 1.3554954528808594}, {'val_loss': 1.3618000745773315}, {'val_loss': 1.3565247058868408}, {'val_loss': 1.3572131395339966}, {'val_loss': 1.3555357456207275}, {'val_loss': 1.3553812503814697}, {'val_loss': 1.3570847511291504}, {'val_loss': 1.3589755296707153}, {'val_loss': 1.3560833930969238}, {'val_loss': 1.3556828498840332}, {'val_loss': 1.3555214405059814}, {'val_loss': 1.3567246198654175}, {'val_loss': 1.3553409576416016}, {'val_loss': 1.3552978038787842}, {'val_loss': 1.355290174484253}, {'val_loss': 1.3553277254104614}, {'val_loss': 1.355329155921936}, {'val_loss': 1.3553335666656494}, {'val_loss': 1.3554489612579346}, {'val_loss': 1.3553740978240967}, {'val_loss': 1.3554728031158447}, {'val_loss': 1.3555446863174438}, {'val_loss': 1.355400800704956}, {'val_loss': 1.35552179813385}, {'val_loss': 1.3552525043487549}, {'val_loss': 1.3552584648132324}, {'val_loss': 1.3552162647247314}, {'val_loss': 1.3552507162094116}, {'val_loss': 1.35527765750885}, {'val_loss': 1.3552875518798828}, {'val_loss': 1.3552402257919312}, {'val_loss': 1.355261206626892}, {'val_loss': 1.3551791906356812}, {'val_loss': 1.3551557064056396}, {'val_loss': 1.355139970779419}, {'val_loss': 1.355112075805664}, {'val_loss': 1.355059266090393}, {'val_loss': 1.355006456375122}, {'val_loss': 1.3549349308013916}, {'val_loss': 1.3549612760543823}, {'val_loss': 1.354889154434204}, {'val_loss': 1.3548805713653564}, {'val_loss': 1.3548870086669922}, {'val_loss': 1.3550840616226196}, {'val_loss': 1.355032205581665}, {'val_loss': 1.355126976966858}, {'val_loss': 1.3553258180618286}, {'val_loss': 1.35527503490448}, 
{'val_loss': 1.3548399209976196}, {'val_loss': 1.354857087135315}, {'val_loss': 1.3548732995986938}, {'val_loss': 1.3548439741134644}, {'val_loss': 1.3548471927642822}, {'val_loss': 1.354846477508545}, {'val_loss': 1.3549150228500366}, {'val_loss': 1.3548537492752075}, {'val_loss': 1.3549411296844482}, {'val_loss': 1.3550673723220825}, {'val_loss': 1.3551135063171387}, {'val_loss': 1.3551299571990967}, {'val_loss': 1.3551310300827026}, {'val_loss': 1.355109453201294}, {'val_loss': 1.3552041053771973}, {'val_loss': 1.3551609516143799}, {'val_loss': 1.3550647497177124}, {'val_loss': 1.3550242185592651}, {'val_loss': 1.3549954891204834}, {'val_loss': 1.354806661605835}, {'val_loss': 1.3548814058303833}, {'val_loss': 1.3550723791122437}, {'val_loss': 1.355039119720459}, {'val_loss': 1.3549984693527222}, {'val_loss': 1.3550713062286377}, {'val_loss': 1.3548802137374878}, {'val_loss': 1.3550053834915161}, {'val_loss': 1.355177640914917}, {'val_loss': 1.3551442623138428}, {'val_loss': 1.3550387620925903}, {'val_loss': 1.354960322380066}, {'val_loss': 1.3548640012741089}, {'val_loss': 1.3548303842544556}, {'val_loss': 1.3547861576080322}, {'val_loss': 1.354933500289917}, {'val_loss': 1.3548729419708252}, {'val_loss': 1.3549764156341553}, {'val_loss': 1.3550817966461182}, {'val_loss': 1.3551771640777588}, {'val_loss': 1.3549935817718506}, {'val_loss': 1.354945421218872}, {'val_loss': 1.3550140857696533}, {'val_loss': 1.355054259300232}, {'val_loss': 1.3547836542129517}, {'val_loss': 1.3548635244369507}, {'val_loss': 1.3549376726150513}, {'val_loss': 1.3549668788909912}, {'val_loss': 1.3548489809036255}, {'val_loss': 1.354942798614502}, {'val_loss': 1.3550866842269897}, {'val_loss': 1.3550900220870972}, {'val_loss': 1.3549708127975464}, {'val_loss': 1.3551058769226074}, {'val_loss': 1.3551361560821533}, {'val_loss': 1.355053424835205}, {'val_loss': 1.354856014251709}, {'val_loss': 1.3548083305358887}, {'val_loss': 1.3548500537872314}, {'val_loss': 1.354717493057251}, 
{'val_loss': 1.3547502756118774}, {'val_loss': 1.3547343015670776}, {'val_loss': 1.3547616004943848}, {'val_loss': 1.3547896146774292}, {'val_loss': 1.3548692464828491}, {'val_loss': 1.3549540042877197}, {'val_loss': 1.3549708127975464}, {'val_loss': 1.3549656867980957}, {'val_loss': 1.3549615144729614}, {'val_loss': 1.35495924949646}, {'val_loss': 1.3549537658691406}, {'val_loss': 1.3549537658691406}, {'val_loss': 1.354970097541809}, {'val_loss': 1.3549647331237793}, {'val_loss': 1.3549704551696777}, {'val_loss': 1.3549609184265137}, {'val_loss': 1.3549612760543823}, {'val_loss': 1.3549517393112183}, {'val_loss': 1.3549433946609497}, {'val_loss': 1.3549416065216064}, {'val_loss': 1.354941964149475}, {'val_loss': 1.354957938194275}, {'val_loss': 1.3549760580062866}, {'val_loss': 1.3549773693084717}, {'val_loss': 1.3549883365631104}, {'val_loss': 1.3549954891204834}, {'val_loss': 1.3550008535385132}, {'val_loss': 1.3549926280975342}, {'val_loss': 1.3549751043319702}, {'val_loss': 1.3549761772155762}, {'val_loss': 1.3549797534942627}, {'val_loss': 1.3549789190292358}, {'val_loss': 1.3549652099609375}, {'val_loss': 1.3549779653549194}, {'val_loss': 1.3549611568450928}, {'val_loss': 1.3549631834030151}, {'val_loss': 1.3549692630767822}, {'val_loss': 1.354960560798645}, {'val_loss': 1.3549777269363403}, {'val_loss': 1.3549540042877197}, {'val_loss': 1.3549739122390747}, {'val_loss': 1.3549824953079224}, {'val_loss': 1.354980707168579}, {'val_loss': 1.3549729585647583}, {'val_loss': 1.3549515008926392}, {'val_loss': 1.354953408241272}, {'val_loss': 1.3549593687057495}, {'val_loss': 1.354970932006836}, {'val_loss': 1.3549659252166748}, {'val_loss': 1.3549597263336182}, {'val_loss': 1.3549662828445435}, {'val_loss': 1.354949951171875}, {'val_loss': 1.3549219369888306}, {'val_loss': 1.354909062385559}, {'val_loss': 1.3549017906188965}, {'val_loss': 1.354910969734192}, {'val_loss': 1.3549164533615112}, {'val_loss': 1.3549116849899292}, {'val_loss': 1.3549190759658813}, 
{'val_loss': 1.354902982711792}, {'val_loss': 1.3549138307571411}, {'val_loss': 1.3549312353134155}, {'val_loss': 1.3549350500106812}, {'val_loss': 1.354937195777893}, {'val_loss': 1.3549233675003052}, {'val_loss': 1.3549250364303589}, {'val_loss': 1.3549437522888184}, {'val_loss': 1.354934573173523}, {'val_loss': 1.3549232482910156}, {'val_loss': 1.3549304008483887}, {'val_loss': 1.354939579963684}, {'val_loss': 1.3549563884735107}, {'val_loss': 1.354954481124878}, {'val_loss': 1.3549622297286987}, {'val_loss': 1.3549681901931763}, {'val_loss': 1.354967713356018}, {'val_loss': 1.3549511432647705}, {'val_loss': 1.3549466133117676}, {'val_loss': 1.3549611568450928}, {'val_loss': 1.3549678325653076}, {'val_loss': 1.3549621105194092}, {'val_loss': 1.3549638986587524}, {'val_loss': 1.3549736738204956}, {'val_loss': 1.3549846410751343}, {'val_loss': 1.3549665212631226}, {'val_loss': 1.3549554347991943}, {'val_loss': 1.3549529314041138}, {'val_loss': 1.35495126247406}, {'val_loss': 1.3549467325210571}, {'val_loss': 1.3549555540084839}, {'val_loss': 1.3549621105194092}, {'val_loss': 1.354956030845642}, {'val_loss': 1.3549562692642212}, {'val_loss': 1.3549613952636719}, {'val_loss': 1.3549690246582031}, {'val_loss': 1.3550021648406982}, {'val_loss': 1.354974389076233}, {'val_loss': 1.3549872636795044}, {'val_loss': 1.354970097541809}, {'val_loss': 1.3549933433532715}, {'val_loss': 1.355014443397522}, {'val_loss': 1.3550028800964355}, {'val_loss': 1.3550028800964355}, {'val_loss': 1.3550089597702026}, {'val_loss': 1.3550013303756714}, {'val_loss': 1.3550052642822266}, {'val_loss': 1.3550039529800415}, {'val_loss': 1.3550015687942505}, {'val_loss': 1.3550002574920654}, {'val_loss': 1.3549972772598267}, {'val_loss': 1.3549988269805908}, {'val_loss': 1.3549987077713013}, {'val_loss': 1.3549968004226685}, {'val_loss': 1.354996919631958}, {'val_loss': 1.354996919631958}, {'val_loss': 1.3549970388412476}, {'val_loss': 1.3549975156784058}, {'val_loss': 1.3549970388412476}, 
{'val_loss': 1.354995846748352}, {'val_loss': 1.3549954891204834}, {'val_loss': 1.3549965620040894}, {'val_loss': 1.3549964427947998}, {'val_loss': 1.3549935817718506}, {'val_loss': 1.3549926280975342}, {'val_loss': 1.3549935817718506}, {'val_loss': 1.354992389678955}, {'val_loss': 1.354993224143982}, {'val_loss': 1.3549939393997192}, {'val_loss': 1.354995846748352}, {'val_loss': 1.3549952507019043}, {'val_loss': 1.3549952507019043}, {'val_loss': 1.3549950122833252}, {'val_loss': 1.3549972772598267}, {'val_loss': 1.3549987077713013}, {'val_loss': 1.3549998998641968}, {'val_loss': 1.3549989461898804}, {'val_loss': 1.3550000190734863}, {'val_loss': 1.3549987077713013}, {'val_loss': 1.3549976348876953}, {'val_loss': 1.3549994230270386}, {'val_loss': 1.3549989461898804}, {'val_loss': 1.3549981117248535}, {'val_loss': 1.3549984693527222}, {'val_loss': 1.3549996614456177}, {'val_loss': 1.35499906539917}, {'val_loss': 1.3549968004226685}, {'val_loss': 1.3549973964691162}, {'val_loss': 1.3549984693527222}, {'val_loss': 1.354998230934143}, {'val_loss': 1.3549988269805908}, {'val_loss': 1.3549988269805908}, {'val_loss': 1.3549960851669312}, {'val_loss': 1.3549968004226685}, {'val_loss': 1.354997992515564}, {'val_loss': 1.3549962043762207}, {'val_loss': 1.3549968004226685}, {'val_loss': 1.354996919631958}, {'val_loss': 1.3549964427947998}, {'val_loss': 1.3549950122833252}, {'val_loss': 1.3549957275390625}, {'val_loss': 1.3549973964691162}, {'val_loss': 1.3549975156784058}, {'val_loss': 1.3549976348876953}, {'val_loss': 1.354999303817749}, {'val_loss': 1.3550000190734863}, {'val_loss': 1.3549994230270386}, {'val_loss': 1.3549977540969849}, {'val_loss': 1.3549987077713013}, {'val_loss': 1.3550001382827759}, {'val_loss': 1.3550000190734863}, {'val_loss': 1.3549977540969849}, {'val_loss': 1.3549970388412476}, {'val_loss': 1.3549978733062744}, {'val_loss': 1.3549968004226685}, {'val_loss': 1.354997158050537}, {'val_loss': 1.3549964427947998}, {'val_loss': 1.3549972772598267}, 
{'val_loss': 1.3549972772598267}, {'val_loss': 1.35499906539917}, {'val_loss': 1.3549988269805908}, {'val_loss': 1.3549983501434326}, {'val_loss': 1.35499906539917}, {'val_loss': 1.3549996614456177}, {'val_loss': 1.35499906539917}, {'val_loss': 1.3549988269805908}, {'val_loss': 1.3550015687942505}, {'val_loss': 1.355000376701355}, {'val_loss': 1.3549975156784058}, {'val_loss': 1.35499906539917}, {'val_loss': 1.3549977540969849}, {'val_loss': 1.3549978733062744}, {'val_loss': 1.3549991846084595}, {'val_loss': 1.3549988269805908}, {'val_loss': 1.3549984693527222}, {'val_loss': 1.3549978733062744}, {'val_loss': 1.3549984693527222}, {'val_loss': 1.3549995422363281}, {'val_loss': 1.354999303817749}, {'val_loss': 1.3549972772598267}, {'val_loss': 1.354998230934143}, {'val_loss': 1.354998230934143}, {'val_loss': 1.3549978733062744}, {'val_loss': 1.354997992515564}, {'val_loss': 1.3549975156784058}, {'val_loss': 1.35499906539917}, {'val_loss': 1.3549996614456177}]

[Figure: loss history plot; y-axis labelled 'losses']
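In the history printed above, val_loss stops improving and hovers around 1.3550 in the final epochs, so further training with the same settings adds little. The following is a minimal sketch of one way to detect such a plateau in code. It pulls val_loss out of each per-epoch dict and compares the spread of the last `window` values against a tolerance. The helper names `extract_losses` and `has_plateaued`, and the `window`/`tol` defaults, are hypothetical choices made for illustration and are not part of the notebook. The sample values are the last five entries of the printed history.

```python
def extract_losses(history):
    """Pull the 'val_loss' value out of each per-epoch result dict."""
    return [result["val_loss"] for result in history]


def has_plateaued(losses, window=5, tol=1e-4):
    """Return True if the last `window` losses vary by less than `tol`.

    `window` and `tol` are illustrative defaults, not values from the notebook.
    """
    if len(losses) < window:
        return False  # not enough epochs to judge yet
    recent = losses[-window:]
    return max(recent) - min(recent) < tol


# The last five entries of the training history printed above
history_tail = [
    {"val_loss": 1.3549978733062744},
    {"val_loss": 1.354997992515564},
    {"val_loss": 1.3549975156784058},
    {"val_loss": 1.35499906539917},
    {"val_loss": 1.3549996614456177},
]

losses = extract_losses(history_tail)
print(has_plateaued(losses))  # True: the spread is about 2e-6
```

A check like this could serve as an early-stopping condition, letting the training loop end once the loss flattens instead of running a fixed number of epochs.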

Step 5: Make predictions using the trained model

In [0]:

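def predict_single(input, target, model):
    # Add a batch dimension: (num_features,) -> (1, num_features), since the model expects batches
    inputs = input.unsqueeze(0)
    # Inference only, so skip building the autograd graph
    with torch.no_grad():
        predictions = model(inputs)
    # Take the single row back out of the batch
    prediction = predictions[0]
    print("Input:", input)
    print("Target:", target)
    print("Prediction:", prediction)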
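`predict_single` looks at one sample at a time. To judge the model on the whole validation set, one option is to batch it with a `DataLoader`, collect the outputs under `torch.no_grad()`, and summarize them with the mean absolute error. The sketch below assumes that `val_ds` yields `(input, target)` pairs, as in the cells that follow. `predict_all` is a hypothetical helper, and the random stand-in dataset and `nn.Linear(6, 1)` model are only there so the block runs on its own. Their shapes match the 6-feature inputs and 1-value targets printed below.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def predict_all(model, dataset, batch_size=64):
    """Run `model` over every sample in `dataset` (hypothetical helper).

    Returns stacked predictions and targets in dataset order.
    """
    loader = DataLoader(dataset, batch_size=batch_size)  # shuffle=False keeps order
    preds, targets = [], []
    model.eval()  # a no-op for a plain linear layer, but good practice
    with torch.no_grad():
        for xb, yb in loader:
            preds.append(model(xb))
            targets.append(yb)
    return torch.cat(preds), torch.cat(targets)


# Stand-in data and model; the values are random placeholders
torch.manual_seed(0)
demo_ds = TensorDataset(torch.randn(20, 6), torch.randn(20, 1))
demo_model = nn.Linear(6, 1)

preds, targets = predict_all(demo_model, demo_ds)
mae = torch.mean(torch.abs(preds - targets)).item()
print(preds.shape, mae)
```

In the notebook itself, the call would be `predict_all(model, val_ds)`.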

In [30]:

input, target = val_ds[0]
predict_single(input, target, model)
Input: tensor([12.0000,  8.2500,  1.0000,  0.0000,  1.0000,  0.0000])
Target: tensor([9.4000])
Prediction: tensor([10.9955])

In [31]:

input, target = val_ds[10]
predict_single(input, target, model)
Input: tensor([12.0000,  5.2500,  2.0000,  0.0000,  1.0000,  0.0000])
Target: tensor([5.9000])
Prediction: tensor([7.5854])

In [32]:

input, target = val_ds[23]
predict_single(input, target, model)
Input: tensor([11.0000,  4.9000,  1.0000,  0.0000,  1.0000,  0.0000])
Target: tensor([8.9300])
Prediction: tensor([6.5628])
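The three samples above give a rough sense of the model's accuracy. It overshoots the first two targets and undershoots the third. The worked check below computes each absolute error and their mean, using only the printed target and prediction values.

```python
# (target, prediction) pairs copied from the three outputs above
samples = [(9.4000, 10.9955), (5.9000, 7.5854), (8.9300, 6.5628)]

# Absolute error per sample, then their mean (MAE)
abs_errors = [abs(pred - target) for target, pred in samples]
mae = sum(abs_errors) / len(abs_errors)

print([round(e, 4) for e in abs_errors])  # [1.5955, 1.6854, 2.3672]
print(round(mae, 4))                      # 1.8827
```

On these samples, the model misses by about 1.88 on average, in the units of the selling-price column. Three samples are far too few to characterize the model, so an error computed over the full validation set gives a more reliable summary.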