Car price prediction through linear regression with PyTorch
Through linear regression, we can predict values based on the correlation between variables. This can be very helpful in areas such as retail and e-commerce, where stores want to know the best selling price for their products.
There are many good tools for implementing linear regression, such as PyTorch and TensorFlow. PyTorch is a Python machine learning package with two main features:
- Tensor computation (like NumPy) with strong GPU acceleration
- Automatic differentiation for building and training neural networks
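To make the second feature concrete, here is a toy, framework-free sketch of automatic differentiation. It is only an illustration under stated assumptions: PyTorch's autograd works in reverse mode on tensors, whereas this sketch uses a different technique, forward mode with dual numbers. Both share the key idea of computing exact derivatives by applying the chain rule to each elementary operation. The `Dual` class and `derivative` helper are hypothetical names, not part of PyTorch.

```python
class Dual:
    """A number value + deriv*eps with eps**2 == 0; deriv carries the derivative."""

    def __init__(self, value, deriv=0.0):
        self.value = value
        self.deriv = deriv

    def _wrap(self, other):
        # Treat plain numbers as constants (derivative 0)
        return other if isinstance(other, Dual) else Dual(other)

    def __add__(self, other):
        other = self._wrap(other)
        # Sum rule: (u + v)' = u' + v'
        return Dual(self.value + other.value, self.deriv + other.deriv)

    __radd__ = __add__

    def __mul__(self, other):
        other = self._wrap(other)
        # Product rule: (uv)' = u'v + uv'
        return Dual(self.value * other.value,
                    self.deriv * other.value + self.value * other.deriv)

    __rmul__ = __mul__


def derivative(f, x):
    """Evaluate f at x with deriv seeded to 1 and return f'(x)."""
    return f(Dual(x, 1.0)).deriv


def f(w):
    return 3 * w * w + 2 * w  # analytically, f'(w) = 6w + 2


print(derivative(f, 2.0))  # 14.0
```

In PyTorch you never write such a class: you mark tensors with `requires_grad=True`, call `backward()` on the loss, and read the gradients from `.grad`.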
In this project, we will use a vehicle dataset from cardekho, a company based in India that provides a car search platform that helps customers buy vehicles according to their needs. Besides the car name and the kilometers driven, which we will not use, the dataset contains the following columns:
- year
- selling price
- showroom price
- fuel type
- seller type
- transmission
- number of previous owners
The dataset for this problem is taken from: https://www.kaggle.com/nehalbirla/vehicle-dataset-from-cardekho
We will create a model with the following steps:
- Download and explore the dataset
- Prepare the dataset for training
- Create a linear regression model
- Train the model to fit the data
- Make predictions using the trained model
In [1]:
```python
# Import libraries
import torch
import torchvision
import torch.nn as nn
import pandas as pd
import matplotlib.pyplot as plt
import torch.nn.functional as F
import seaborn as sns
from torchvision.datasets.utils import download_url
from torch.utils.data import DataLoader, TensorDataset, random_split
```
Step 1: Download and explore the data
Let us begin by uploading the CSV file downloaded from Kaggle into the Colab session.
In [5]:
```python
# Import .csv file
from google.colab import files
uploaded = files.upload()
```
Saving datasets_33080_43333_car_data.csv to datasets_33080_43333_car_data.csv
In [6]:
```python
# Create Pandas Dataframe
import io
dataframe = pd.read_csv(io.BytesIO(uploaded['datasets_33080_43333_car_data.csv']))
dataframe = dataframe.drop(['Car_Name', 'Kms_Driven'], axis=1)  # Exclude columns that we won't use
dataframe
```
Out[6]:
| | Year | Selling_Price | Present_Price | Fuel_Type | Seller_Type | Transmission | Owner |
|---|---|---|---|---|---|---|---|
| 0 | 2014 | 3.35 | 5.59 | Petrol | Dealer | Manual | 0 |
| 1 | 2013 | 4.75 | 9.54 | Diesel | Dealer | Manual | 0 |
| 2 | 2017 | 7.25 | 9.85 | Petrol | Dealer | Manual | 0 |
| 3 | 2011 | 2.85 | 4.15 | Petrol | Dealer | Manual | 0 |
| 4 | 2014 | 4.60 | 6.87 | Diesel | Dealer | Manual | 0 |
| … | … | … | … | … | … | … | … |
| 296 | 2016 | 9.50 | 11.60 | Diesel | Dealer | Manual | 0 |
| 297 | 2015 | 4.00 | 5.90 | Petrol | Dealer | Manual | 0 |
| 298 | 2009 | 3.35 | 11.00 | Petrol | Dealer | Manual | 0 |
| 299 | 2017 | 11.50 | 12.50 | Diesel | Dealer | Manual | 0 |
| 300 | 2016 | 5.30 | 5.90 | Petrol | Dealer | Manual | 0 |
301 rows × 7 columns
In [0]:
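```python
num_rows = len(dataframe)
# Features used to predict the target
input_cols = ['Year', 'Selling_Price', 'Fuel_Type', 'Seller_Type', 'Transmission', 'Owner']
# Columns converted to integer category codes
categorical_cols = ['Year', 'Fuel_Type', 'Seller_Type', 'Transmission', 'Owner']
# Target: the showroom price
output_cols = ['Present_Price']
```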
In [8]:
```python
print('Minimum Price: {}'.format(dataframe.Present_Price.min()))
print('Average Price: {}'.format(dataframe.Present_Price.mean()))
print('Maximum Price: {}'.format(dataframe.Present_Price.max()))
```
```
Minimum Price: 0.32
Average Price: 7.628471760797344
Maximum Price: 92.6
```
In [9]:
```python
plt.title("Distribution of Showroom Price")
# sns.distplot is deprecated in recent seaborn; histplot draws the same histogram
sns.histplot(dataframe.Present_Price, kde=False);
```
Step 2: Prepare the dataset for training
In [0]:
```python
def dataframe_to_arrays(dataframe):
    # Make a copy of the original dataframe
    dataframe1 = dataframe.copy(deep=True)
    # Convert categorical columns (including Year and Owner) to integer codes
    for col in categorical_cols:
        dataframe1[col] = dataframe1[col].astype('category').cat.codes
    # Extract inputs & outputs as numpy arrays
    inputs_array = dataframe1[input_cols].to_numpy()
    targets_array = dataframe1[output_cols].to_numpy()
    return inputs_array, targets_array
```
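When the categories are inferred by `astype('category')`, pandas orders them by sorting the unique values, and `cat.codes` is each value's position in that order. The hypothetical helper below is a pure-Python sketch of that mapping, assuming no missing values. It agrees with the encoded arrays in Out[11], where Diesel becomes 1 and Petrol becomes 2; the sample list assumes the third fuel type, which sorts before Diesel, is 'CNG'.

```python
def category_codes(values):
    """Map each value to its index among the sorted unique values.

    A sketch of what pandas' astype('category').cat.codes returns
    when categories are inferred and no values are missing.
    """
    categories = sorted(set(values))
    lookup = {category: code for code, category in enumerate(categories)}
    return [lookup[value] for value in values], categories


fuel = ['Petrol', 'Diesel', 'Petrol', 'CNG', 'Diesel']  # 'CNG' is an assumed value
codes, categories = category_codes(fuel)
print(categories)  # ['CNG', 'Diesel', 'Petrol']
print(codes)       # [2, 1, 2, 0, 1]
```

One consequence worth keeping in mind: integer codes impose an order (Petrol > Diesel > CNG) that a linear model treats as a numeric quantity. One-hot encoding, for example with `pd.get_dummies`, is a common alternative.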
In [11]:
```python
# Create arrays for inputs and targets
inputs_array, targets_array = dataframe_to_arrays(dataframe)
inputs_array, targets_array
```
Out[11]:
(array([[11. , 3.35, 2. , 0. , 1. , 0. ], [10. , 4.75, 1. , 0. , 1. , 0. ], [14. , 7.25, 2. , 0. , 1. , 0. ], ..., [ 6. , 3.35, 2. , 0. , 1. , 0. ], [14. , 11.5 , 1. , 0. , 1. , 0. ], [13. , 5.3 , 2. , 0. , 1. , 0. ]]), array([[ 5.59 ], [ 9.54 ], [ 9.85 ], [ 4.15 ], [ 6.87 ], [ 9.83 ], [ 8.12 ], [ 8.61 ], [ 8.89 ], [ 8.92 ], [ 3.6 ], [10.38 ], [ 9.94 ], [ 7.71 ], [ 7.21 ], [10.79 ], [10.79 ], [10.79 ], [ 5.09 ], [ 7.98 ], [ 3.95 ], [ 5.71 ], [ 8.01 ], [ 3.46 ], [ 4.41 ], [ 4.99 ], [ 5.87 ], [ 6.49 ], [ 3.95 ], [10.38 ], [ 5.98 ], [ 4.89 ], [ 7.49 ], [ 9.95 ], [ 8.06 ], [ 7.74 ], [ 7.2 ], [ 2.28 ], [ 3.76 ], [ 7.98 ], [ 7.87 ], [ 3.98 ], [ 7.15 ], [ 8.06 ], [ 2.69 ], [12.04 ], [ 4.89 ], [ 4.15 ], [ 7.71 ], [ 9.29 ], [30.61 ], [30.61 ], [19.77 ], [30.61 ], [10.21 ], [15.04 ], [ 7.27 ], [18.54 ], [ 6.8 ], [35.96 ], [18.61 ], [ 7.7 ], [35.96 ], [35.96 ], [36.23 ], [ 6.95 ], [23.15 ], [20.45 ], [13.74 ], [20.91 ], [ 6.76 ], [12.48 ], [18.61 ], [ 5.71 ], [ 8.93 ], [ 6.8 ], [14.68 ], [12.35 ], [22.83 ], [30.61 ], [14.89 ], [ 7.85 ], [25.39 ], [13.46 ], [13.46 ], [23.73 ], [92.6 ], [13.74 ], [ 6.05 ], [ 6.76 ], [18.61 ], [16.09 ], [13.7 ], [30.61 ], [22.78 ], [18.61 ], [25.39 ], [18.64 ], [18.61 ], [20.45 ], [ 1.9 ], [ 1.82 ], [ 1.78 ], [ 1.6 ], [ 1.47 ], [ 2.37 ], [ 3.45 ], [ 1.5 ], [ 1.5 ], [ 1.47 ], [ 1.78 ], [ 1.5 ], [ 2.4 ], [ 1.4 ], [ 1.47 ], [ 1.47 ], [ 1.47 ], [ 1.9 ], [ 1.47 ], [ 1.9 ], [ 1.26 ], [ 1.5 ], [ 1.17 ], [ 1.47 ], [ 1.75 ], [ 1.75 ], [ 0.95 ], [ 0.8 ], [ 0.87 ], [ 0.84 ], [ 0.87 ], [ 0.82 ], [ 0.95 ], [ 0.95 ], [ 0.81 ], [ 0.74 ], [ 1.2 ], [ 0.787], [ 0.87 ], [ 0.95 ], [ 1.2 ], [ 0.8 ], [ 0.84 ], [ 0.84 ], [ 0.99 ], [ 0.81 ], [ 0.787], [ 0.84 ], [ 0.94 ], [ 0.94 ], [ 0.826], [ 0.55 ], [ 0.99 ], [ 0.99 ], [ 0.88 ], [ 0.51 ], [ 0.52 ], [ 0.84 ], [ 0.54 ], [ 0.51 ], [ 0.95 ], [ 0.826], [ 0.99 ], [ 0.95 ], [ 0.54 ], [ 0.54 ], [ 0.55 ], [ 0.81 ], [ 0.73 ], [ 0.54 ], [ 0.83 ], [ 0.55 ], [ 0.64 ], [ 0.51 ], [ 0.72 ], [ 0.787], [ 1.05 ], [ 0.57 ], [ 0.52 ], [ 
1.05 ], [ 0.51 ], [ 0.48 ], [ 0.58 ], [ 0.47 ], [ 0.75 ], [ 0.58 ], [ 0.52 ], [ 0.51 ], [ 0.57 ], [ 0.57 ], [ 0.75 ], [ 0.57 ], [ 0.75 ], [ 0.65 ], [ 0.787], [ 0.32 ], [ 0.52 ], [ 0.51 ], [ 0.57 ], [ 0.58 ], [ 0.75 ], [ 6.79 ], [ 5.7 ], [ 4.6 ], [ 4.43 ], [ 5.7 ], [ 7.13 ], [ 5.7 ], [ 8.1 ], [ 5.7 ], [ 4.6 ], [14.79 ], [13.6 ], [ 6.79 ], [ 5.7 ], [ 9.4 ], [ 4.43 ], [ 4.43 ], [ 9.4 ], [ 9.4 ], [ 4.43 ], [ 6.79 ], [ 7.6 ], [ 9.4 ], [ 9.4 ], [ 4.6 ], [ 5.7 ], [ 4.43 ], [ 9.4 ], [ 6.79 ], [ 9.4 ], [ 9.4 ], [14.79 ], [ 5.7 ], [ 5.7 ], [ 9.4 ], [ 4.43 ], [13.6 ], [ 9.4 ], [ 4.43 ], [ 9.4 ], [ 7.13 ], [ 7.13 ], [ 7.6 ], [ 9.4 ], [ 9.4 ], [ 6.79 ], [ 9.4 ], [ 4.6 ], [ 7.6 ], [13.6 ], [ 9.9 ], [ 6.82 ], [ 9.9 ], [ 9.9 ], [ 5.35 ], [13.6 ], [13.6 ], [13.6 ], [ 7. ], [13.6 ], [ 5.97 ], [ 5.8 ], [ 7.7 ], [ 7. ], [ 8.7 ], [ 7. ], [ 9.4 ], [ 5.8 ], [10. ], [10. ], [10. ], [10. ], [ 7.5 ], [ 6.8 ], [13.6 ], [13.6 ], [13.6 ], [ 8.4 ], [13.6 ], [ 5.9 ], [ 7.6 ], [14. ], [11.8 ], [ 5.9 ], [ 8.5 ], [ 7.9 ], [ 7.5 ], [13.6 ], [13.6 ], [ 6.4 ], [ 6.1 ], [ 8.4 ], [ 9.9 ], [ 6.8 ], [13.09 ], [11.6 ], [ 5.9 ], [11. ], [12.5 ], [ 5.9 ]]))
In [0]:
```python
# Create PyTorch Tensors from Numpy arrays
inputs = torch.from_numpy(inputs_array).float()
targets = torch.from_numpy(targets_array).float()
```
Next, we need to create PyTorch datasets and data loaders for training and validation. We'll start by creating a TensorDataset.
In [0]:
```python
dataset = TensorDataset(inputs, targets)
```
In [0]:
```python
val_percent = 0.1
val_size = int(num_rows * val_percent)
train_size = num_rows - val_size
# Split the dataset into 2 parts of the desired length
train_ds, val_ds = random_split(dataset, [train_size, val_size])
```
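Conceptually, `random_split` draws a random permutation of the indices and cuts it into consecutive pieces of the requested lengths. The hypothetical function below is a pure-Python sketch of that idea for the sizes used here: with 301 rows, `int(301 * 0.1)` gives 30 validation rows and 271 training rows.

```python
import random


def random_split_indices(n, val_percent, seed=None):
    """Shuffle indices 0..n-1 and cut them into train/validation parts."""
    val_size = int(n * val_percent)
    train_size = n - val_size
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # a seeded RNG makes the split repeatable
    return indices[:train_size], indices[train_size:]


train_idx, val_idx = random_split_indices(301, 0.1, seed=42)
print(len(train_idx), len(val_idx))  # 271 30
```

Because the notebook calls `random_split` without a seed, the split (and therefore the exact losses) will differ between runs. `random_split` accepts a `generator` argument, such as `torch.Generator().manual_seed(42)`, if you want a reproducible split.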
Finally, we can create data loaders for training & validation.
In [0]:
```python
batch_size = 50
train_loader = DataLoader(train_ds, batch_size, shuffle=True)
val_loader = DataLoader(val_ds, batch_size)
```
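To see what these loaders yield, the hypothetical generator below sketches one epoch of simple batching in pure Python. It assumes the 271/30 split above. With `DataLoader`'s default `drop_last=False`, the final, smaller batch is kept, so each training epoch has five batches of 50 and one of 21, while the 30 validation rows fit in a single batch.

```python
import random


def batch_indices(n, batch_size, shuffle=False, seed=None):
    """Yield lists of sample indices, batch by batch, for one epoch."""
    indices = list(range(n))
    if shuffle:
        random.Random(seed).shuffle(indices)
    for start in range(0, n, batch_size):
        yield indices[start:start + batch_size]  # last batch may be smaller


train_batches = list(batch_indices(271, 50, shuffle=True, seed=0))
print([len(b) for b in train_batches])  # [50, 50, 50, 50, 50, 21]
val_batches = list(batch_indices(30, 50))
print([len(b) for b in val_batches])    # [30]
```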
In [16]:
```python
for xb, yb in train_loader:
    print("inputs:", xb)
    print("targets:", yb)
    break
```
inputs: tensor([[ 5.0000, 0.2000, 2.0000, 1.0000, 1.0000, 0.0000], [12.0000, 6.8500, 1.0000, 0.0000, 1.0000, 0.0000], [ 3.0000, 1.0500, 2.0000, 0.0000, 1.0000, 0.0000], [ 6.0000, 0.9000, 2.0000, 1.0000, 1.0000, 0.0000], [ 9.0000, 4.9500, 1.0000, 0.0000, 1.0000, 0.0000], [13.0000, 1.0500, 2.0000, 1.0000, 1.0000, 0.0000], [ 6.0000, 3.3500, 2.0000, 0.0000, 1.0000, 0.0000], [11.0000, 2.5500, 2.0000, 0.0000, 1.0000, 0.0000], [12.0000, 0.4000, 2.0000, 1.0000, 1.0000, 0.0000], [11.0000, 2.5000, 2.0000, 0.0000, 1.0000, 0.0000], [13.0000, 20.7500, 1.0000, 0.0000, 0.0000, 0.0000], [ 8.0000, 3.0000, 2.0000, 0.0000, 1.0000, 0.0000], [ 5.0000, 4.0000, 2.0000, 0.0000, 0.0000, 0.0000], [13.0000, 0.7500, 2.0000, 1.0000, 1.0000, 0.0000], [13.0000, 6.6000, 2.0000, 0.0000, 1.0000, 0.0000], [11.0000, 3.7500, 2.0000, 0.0000, 1.0000, 0.0000], [11.0000, 3.3500, 2.0000, 0.0000, 1.0000, 0.0000], [ 9.0000, 14.5000, 1.0000, 0.0000, 0.0000, 0.0000], [13.0000, 6.0000, 2.0000, 0.0000, 1.0000, 0.0000], [10.0000, 1.2500, 2.0000, 1.0000, 1.0000, 0.0000], [13.0000, 0.4500, 2.0000, 1.0000, 0.0000, 0.0000], [12.0000, 4.6500, 2.0000, 0.0000, 1.0000, 0.0000], [12.0000, 23.5000, 1.0000, 0.0000, 0.0000, 0.0000], [ 2.0000, 2.7500, 2.0000, 1.0000, 1.0000, 0.0000], [ 3.0000, 2.1000, 2.0000, 0.0000, 1.0000, 0.0000], [13.0000, 14.7300, 1.0000, 0.0000, 1.0000, 0.0000], [11.0000, 3.6500, 2.0000, 0.0000, 1.0000, 0.0000], [10.0000, 0.2700, 2.0000, 1.0000, 1.0000, 0.0000], [10.0000, 0.6500, 2.0000, 1.0000, 1.0000, 0.0000], [11.0000, 8.2500, 1.0000, 0.0000, 1.0000, 0.0000], [13.0000, 0.6000, 2.0000, 1.0000, 1.0000, 0.0000], [ 9.0000, 2.0000, 2.0000, 0.0000, 1.0000, 0.0000], [11.0000, 3.9500, 1.0000, 0.0000, 1.0000, 0.0000], [ 7.0000, 35.0000, 1.0000, 0.0000, 1.0000, 0.0000], [13.0000, 6.4000, 2.0000, 0.0000, 1.0000, 0.0000], [10.0000, 2.6500, 2.0000, 0.0000, 1.0000, 0.0000], [10.0000, 6.9500, 2.0000, 0.0000, 1.0000, 0.0000], [10.0000, 0.4200, 2.0000, 1.0000, 1.0000, 0.0000], [10.0000, 5.1100, 2.0000, 0.0000, 
0.0000, 0.0000], [ 5.0000, 0.2000, 2.0000, 1.0000, 1.0000, 0.0000], [14.0000, 23.0000, 1.0000, 0.0000, 0.0000, 0.0000], [11.0000, 5.3000, 2.0000, 0.0000, 1.0000, 0.0000], [12.0000, 1.1000, 2.0000, 1.0000, 1.0000, 0.0000], [ 8.0000, 0.3500, 2.0000, 1.0000, 1.0000, 0.0000], [ 7.0000, 2.6500, 2.0000, 0.0000, 1.0000, 0.0000], [11.0000, 0.3500, 2.0000, 1.0000, 0.0000, 0.0000], [14.0000, 0.4000, 2.0000, 1.0000, 0.0000, 0.0000], [13.0000, 11.2500, 2.0000, 0.0000, 1.0000, 0.0000], [ 6.0000, 2.2500, 2.0000, 0.0000, 1.0000, 0.0000], [ 3.0000, 2.5000, 2.0000, 1.0000, 0.0000, 2.0000]]) targets: tensor([[ 0.7500], [10.3800], [ 4.1500], [ 1.7500], [ 9.4000], [ 1.2600], [11.0000], [ 3.9800], [ 0.5500], [ 3.4600], [25.3900], [ 4.9900], [22.7800], [ 0.8000], [ 7.7000], [ 6.8000], [ 5.5900], [30.6100], [ 8.4000], [ 1.5000], [ 0.5400], [ 7.2000], [35.9600], [10.2100], [ 7.6000], [14.8900], [ 7.0000], [ 0.4700], [ 0.7870], [14.0000], [ 0.8700], [ 4.4300], [ 6.7600], [92.6000], [ 8.4000], [ 4.8900], [18.6100], [ 0.7300], [ 9.4000], [ 0.7870], [25.3900], [ 6.8000], [ 1.4700], [ 1.0500], [ 7.9800], [ 0.5200], [ 0.5100], [13.6000], [ 7.2100], [23.7300]])
Step 3: Create a Linear Regression Model
In [0]:
```python
input_size = len(input_cols)
output_size = len(output_cols)
```
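With `input_size = 6` and `output_size = 1`, the model is a single `nn.Linear` layer. PyTorch documents this layer as computing $y = xA^\top + b$, where the weight $A$ has shape (output features, input features), here $1 \times 6$. For one encoded car $\mathbf{x} = (x_1, \dots, x_6)$ (Year, Selling_Price, Fuel_Type, Seller_Type, Transmission, Owner), the predicted showroom price is:

$$
\hat{y} = \mathbf{x} A^\top + b = \sum_{j=1}^{6} a_{j}\, x_{j} + b,
\qquad A \in \mathbb{R}^{1 \times 6},\; b \in \mathbb{R}
$$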
In [0]:
```python
class PriceModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, xb):
        out = self.linear(xb)
        return out

    def training_step(self, batch):
        inputs, targets = batch
        # Generate predictions
        out = self(inputs)
        # Calculate loss
        loss = F.l1_loss(out, targets)
        return loss

    def validation_step(self, batch):
        inputs, targets = batch
        # Generate predictions
        out = self(inputs)
        # Calculate loss
        loss = F.l1_loss(out, targets)
        return {'val_loss': loss.detach()}

    def validation_epoch_end(self, outputs):
        batch_losses = [x['val_loss'] for x in outputs]
        epoch_loss = torch.stack(batch_losses).mean()
        return {'val_loss': epoch_loss.item()}

    def epoch_end(self, epoch, result, num_epochs):
        # Print result every 20th epoch
        if (epoch+1) % 20 == 0 or epoch == num_epochs-1:
            print("Epoch [{}], val_loss: {:.4f}".format(epoch+1, result['val_loss']))
```
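`F.l1_loss` with its default `reduction='mean'` is the mean absolute error. The sketch below pairs four showroom prices from this dataset (including the 92.6 maximum) with hypothetical predictions and compares L1 with mean squared error. A single large miss dominates the squared loss but contributes only linearly to the L1 loss. This is one plausible reason to prefer L1 on data with such outliers; the original notebook does not state its reason.

```python
def l1_loss(predictions, targets):
    """Mean absolute error, as F.l1_loss computes with reduction='mean'."""
    return sum(abs(p - t) for p, t in zip(predictions, targets)) / len(targets)


def mse_loss(predictions, targets):
    """Mean squared error, shown only for comparison."""
    return sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(targets)


targets = [5.59, 9.54, 9.85, 92.6]    # showroom prices from the dataset
predictions = [6.0, 9.0, 10.0, 20.0]  # hypothetical model outputs

print(l1_loss(predictions, targets))   # about 18.43
print(mse_loss(predictions, targets))  # about 1317.81
```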
In [0]:
```python
model = PriceModel()
```
Let’s check out the weights and biases of the model using model.parameters().
In [20]:
```python
list(model.parameters())
```
Out[20]:
[Parameter containing: tensor([[-0.3935, -0.3412, -0.0166, -0.0592, 0.0338, -0.1588]], requires_grad=True), Parameter containing: tensor([-0.3614], requires_grad=True)]
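As a quick check of what the untrained model computes, the sketch below applies the printed initial parameters to the first encoded input row from Out[11], `[11, 3.35, 2, 0, 1, 0]`. The printed weights are rounded to four decimals, so this only approximates what `model(...)` would return at this point. The target for that row is 5.59, so the untrained model is far off, which is consistent with the large initial validation loss.

```python
def linear_forward(x, weight, bias):
    """Compute sum(w_j * x_j) + b for one sample, as nn.Linear does for one output."""
    return sum(w * xi for w, xi in zip(weight, x)) + bias


# Initial parameters printed above (rounded to 4 decimals)
weight = [-0.3935, -0.3412, -0.0166, -0.0592, 0.0338, -0.1588]
bias = -0.3614

# First encoded input row from Out[11]
x0 = [11.0, 3.35, 2.0, 0.0, 1.0, 0.0]
print(round(linear_forward(x0, weight, bias), 4))  # -5.8323
```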
Step 4: Train the model to fit the data
In [0]:
```python
def evaluate(model, val_loader):
    # Gradients are not needed for validation
    with torch.no_grad():
        outputs = [model.validation_step(batch) for batch in val_loader]
    return model.validation_epoch_end(outputs)

def fit(epochs, lr, model, train_loader, val_loader, opt_func=torch.optim.SGD):
    history = []
    optimizer = opt_func(model.parameters(), lr)
    for epoch in range(epochs):
        # Training phase
        for batch in train_loader:
            loss = model.training_step(batch)
            loss.backward()        # compute gradients
            optimizer.step()       # update weights and bias
            optimizer.zero_grad()  # reset gradients for the next batch
        # Validation phase
        result = evaluate(model, val_loader)
        model.epoch_end(epoch, result, epochs)
        history.append(result)
    return history
```
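For plain SGD without momentum, `optimizer.step()` updates each parameter as `param -= lr * param.grad`. The framework-free sketch below does the same for a one-feature model on hypothetical toy data lying on the line y = 2x. It computes the L1 gradient by hand, using sign(0) = 0 as the subgradient where the absolute value has a kink, and it uses the full dataset per step rather than mini-batches. In PyTorch, `loss.backward()` accumulates into `.grad`, which is why `fit` calls `optimizer.zero_grad()` after every step.

```python
def l1_subgradient_step(w, b, xs, ys, lr):
    """One plain-SGD step on the mean absolute error of y_hat = w*x + b."""
    n = len(xs)
    grad_w = grad_b = 0.0
    for x, y in zip(xs, ys):
        residual = w * x + b - y
        sign = (residual > 0) - (residual < 0)  # subgradient of |r|, sign(0) = 0
        grad_w += sign * x / n
        grad_b += sign / n
    # Same rule as SGD's step(): parameter -= lr * gradient
    return w - lr * grad_w, b - lr * grad_b


def mae(w, b, xs, ys):
    return sum(abs(w * x + b - y) for x, y in zip(xs, ys)) / len(xs)


xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]  # hypothetical toy data on the line y = 2x
w, b = 0.0, 0.0
initial = mae(w, b, xs, ys)
for epoch in range(200):
    w, b = l1_subgradient_step(w, b, xs, ys, lr=0.05)
print(initial, mae(w, b, xs, ys))  # loss falls from 5.0 to a small value
```

With a constant learning rate, the parameters end up oscillating in a small neighborhood of the optimum instead of settling exactly. This is the same behavior the validation losses show at the larger learning rates.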
In [22]:
```python
result = evaluate(model, val_loader)  # Validation loss of the untrained model
print(result)
```
{'val_loss': 13.508749961853027}
In [23]:
```python
epochs = 100
lr = 1e-2
history1 = fit(epochs, lr, model, train_loader, val_loader)
```
```
Epoch [20], val_loss: 1.6579
Epoch [40], val_loss: 1.4193
Epoch [60], val_loss: 1.4849
Epoch [80], val_loss: 1.3621
Epoch [100], val_loss: 1.4464
```
In [24]:
```python
epochs = 100
lr = 1e-3
history2 = fit(epochs, lr, model, train_loader, val_loader)
```
```
Epoch [20], val_loss: 1.3564
Epoch [40], val_loss: 1.3603
Epoch [60], val_loss: 1.3568
Epoch [80], val_loss: 1.3567
Epoch [100], val_loss: 1.3553
```
In [25]:
```python
epochs = 100
lr = 1e-4
history3 = fit(epochs, lr, model, train_loader, val_loader)
```
```
Epoch [20], val_loss: 1.3552
Epoch [40], val_loss: 1.3548
Epoch [60], val_loss: 1.3551
Epoch [80], val_loss: 1.3549
Epoch [100], val_loss: 1.3550
```
In [26]:
```python
epochs = 100
lr = 1e-5
history4 = fit(epochs, lr, model, train_loader, val_loader)
```
```
Epoch [20], val_loss: 1.3550
Epoch [40], val_loss: 1.3550
Epoch [60], val_loss: 1.3549
Epoch [80], val_loss: 1.3550
Epoch [100], val_loss: 1.3550
```
In [27]:
```python
epochs = 100
lr = 1e-6
history5 = fit(epochs, lr, model, train_loader, val_loader)
```
```
Epoch [20], val_loss: 1.3550
Epoch [40], val_loss: 1.3550
Epoch [60], val_loss: 1.3550
Epoch [80], val_loss: 1.3550
Epoch [100], val_loss: 1.3550
```
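Taken together, the five calls to `fit()` follow a step-decay schedule: the learning rate starts at 1e-2 and is divided by 10 every 100 epochs. The hypothetical helper below lists that schedule. Because plain SGD without momentum keeps no per-parameter state, creating a fresh optimizer in each `fit()` call does not change the update rule. PyTorch's `torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.1)` expresses the same schedule within a single training loop, although `fit` would need restructuring to use it. The logs also show that the validation loss has essentially plateaued near 1.355 after the 1e-3 stage, so the last three stages change very little.

```python
def step_decay_lr(epoch, initial_lr=1e-2, step_size=100, gamma=0.1):
    """Learning rate at a given epoch under step decay."""
    return initial_lr * gamma ** (epoch // step_size)


# 500 epochs, matching the five 100-epoch runs above
schedule = [step_decay_lr(epoch) for epoch in range(500)]
print(schedule[0], schedule[99], schedule[100], schedule[499])
```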
In [28]:
```python
val_loss = [result] + history1 + history2 + history3 + history4 + history5
print(val_loss)
val_loss_list = [vl['val_loss'] for vl in val_loss]
plt.plot(val_loss_list, '-x')
plt.xlabel('epochs')
plt.ylabel('losses')
```
[{'val_loss': 13.508749961853027}, {'val_loss': 6.118033409118652}, {'val_loss': 4.425381660461426}, {'val_loss': 3.5121304988861084}, {'val_loss': 3.141930103302002}, {'val_loss': 2.6487579345703125}, {'val_loss': 2.2823987007141113}, {'val_loss': 2.0026051998138428}, {'val_loss': 1.78737473487854}, {'val_loss': 1.600001573562622}, {'val_loss': 1.4982473850250244}, {'val_loss': 1.4682856798171997}, {'val_loss': 1.4618252515792847}, {'val_loss': 1.4482910633087158}, {'val_loss': 1.5700268745422363}, {'val_loss': 1.449874758720398}, {'val_loss': 1.4577511548995972}, {'val_loss': 1.4320111274719238}, {'val_loss': 1.4231066703796387}, {'val_loss': 1.4389463663101196}, {'val_loss': 1.65794038772583}, {'val_loss': 1.4212898015975952}, {'val_loss': 1.4200899600982666}, {'val_loss': 1.435904860496521}, {'val_loss': 1.5388320684432983}, {'val_loss': 1.4144498109817505}, {'val_loss': 1.4971991777420044}, {'val_loss': 1.5992761850357056}, {'val_loss': 1.4277279376983643}, {'val_loss': 1.5031235218048096}, {'val_loss': 1.4794392585754395}, {'val_loss': 1.4153716564178467}, {'val_loss': 1.4078534841537476}, {'val_loss': 1.510008454322815}, {'val_loss': 1.4809588193893433}, {'val_loss': 1.4999046325683594}, {'val_loss': 1.3946417570114136}, {'val_loss': 1.6055680513381958}, {'val_loss': 1.4423900842666626}, {'val_loss': 1.4460539817810059}, {'val_loss': 1.419303297996521}, {'val_loss': 1.3937510251998901}, {'val_loss': 1.4011147022247314}, {'val_loss': 1.4060081243515015}, {'val_loss': 1.3984336853027344}, {'val_loss': 1.4633814096450806}, {'val_loss': 1.3751095533370972}, {'val_loss': 1.4318023920059204}, {'val_loss': 1.6094945669174194}, {'val_loss': 1.3969550132751465}, {'val_loss': 1.7303458452224731}, {'val_loss': 1.3911798000335693}, {'val_loss': 1.3729671239852905}, {'val_loss': 1.54682457447052}, {'val_loss': 1.3787986040115356}, {'val_loss': 1.649728536605835}, {'val_loss': 1.3683972358703613}, {'val_loss': 1.367186188697815}, {'val_loss': 1.4891501665115356}, 
{'val_loss': 1.3684563636779785}, {'val_loss': 1.4849108457565308}, {'val_loss': 1.37216055393219}, {'val_loss': 1.3770359754562378}, {'val_loss': 1.3935377597808838}, {'val_loss': 1.3642178773880005}, {'val_loss': 1.4782458543777466}, {'val_loss': 1.4224978685379028}, {'val_loss': 1.4049519300460815}, {'val_loss': 1.4761677980422974}, {'val_loss': 1.5544482469558716}, {'val_loss': 1.3702653646469116}, {'val_loss': 1.4356273412704468}, {'val_loss': 1.400244951248169}, {'val_loss': 1.377777338027954}, {'val_loss': 1.448345422744751}, {'val_loss': 1.3617547750473022}, {'val_loss': 1.6235082149505615}, {'val_loss': 1.4901325702667236}, {'val_loss': 1.542846441268921}, {'val_loss': 1.360831379890442}, {'val_loss': 1.3620823621749878}, {'val_loss': 1.3667073249816895}, {'val_loss': 1.3608430624008179}, {'val_loss': 1.5464335680007935}, {'val_loss': 1.433320164680481}, {'val_loss': 1.3586939573287964}, {'val_loss': 1.3673779964447021}, {'val_loss': 1.3756794929504395}, {'val_loss': 1.4192850589752197}, {'val_loss': 1.372723937034607}, {'val_loss': 1.4133044481277466}, {'val_loss': 1.3854647874832153}, {'val_loss': 1.364600658416748}, {'val_loss': 1.3931787014007568}, {'val_loss': 1.4342844486236572}, {'val_loss': 1.3628737926483154}, {'val_loss': 1.5067977905273438}, {'val_loss': 1.3583155870437622}, {'val_loss': 1.356598138809204}, {'val_loss': 1.37332022190094}, {'val_loss': 1.4464164972305298}, {'val_loss': 1.3635132312774658}, {'val_loss': 1.3564374446868896}, {'val_loss': 1.3569012880325317}, {'val_loss': 1.3573507070541382}, {'val_loss': 1.3600316047668457}, {'val_loss': 1.3589403629302979}, {'val_loss': 1.3563119173049927}, {'val_loss': 1.3571209907531738}, {'val_loss': 1.356133222579956}, {'val_loss': 1.356939435005188}, {'val_loss': 1.359302282333374}, {'val_loss': 1.3567982912063599}, {'val_loss': 1.357097864151001}, {'val_loss': 1.3570212125778198}, {'val_loss': 1.3584320545196533}, {'val_loss': 1.3569546937942505}, {'val_loss': 1.357675313949585}, 
{'val_loss': 1.3573272228240967}, {'val_loss': 1.3567001819610596}, {'val_loss': 1.356406807899475}, {'val_loss': 1.3568202257156372}, {'val_loss': 1.360009789466858}, {'val_loss': 1.3571254014968872}, {'val_loss': 1.357373595237732}, {'val_loss': 1.3569358587265015}, {'val_loss': 1.3566490411758423}, {'val_loss': 1.358595609664917}, {'val_loss': 1.356493353843689}, {'val_loss': 1.3569475412368774}, {'val_loss': 1.3568549156188965}, {'val_loss': 1.3565493822097778}, {'val_loss': 1.3592419624328613}, {'val_loss': 1.3565335273742676}, {'val_loss': 1.3574392795562744}, {'val_loss': 1.3569074869155884}, {'val_loss': 1.3572746515274048}, {'val_loss': 1.3571385145187378}, {'val_loss': 1.357692003250122}, {'val_loss': 1.3599610328674316}, {'val_loss': 1.3603343963623047}, {'val_loss': 1.3565924167633057}, {'val_loss': 1.3568990230560303}, {'val_loss': 1.3564502000808716}, {'val_loss': 1.3568381071090698}, {'val_loss': 1.3563752174377441}, {'val_loss': 1.3565641641616821}, {'val_loss': 1.3578929901123047}, {'val_loss': 1.3564704656600952}, {'val_loss': 1.3579374551773071}, {'val_loss': 1.3567204475402832}, {'val_loss': 1.358005404472351}, {'val_loss': 1.3587855100631714}, {'val_loss': 1.3610790967941284}, {'val_loss': 1.3572267293930054}, {'val_loss': 1.3567177057266235}, {'val_loss': 1.3564910888671875}, {'val_loss': 1.3565599918365479}, {'val_loss': 1.3571439981460571}, {'val_loss': 1.3573533296585083}, {'val_loss': 1.3567684888839722}, {'val_loss': 1.3597580194473267}, {'val_loss': 1.3570773601531982}, {'val_loss': 1.3559664487838745}, {'val_loss': 1.3565402030944824}, {'val_loss': 1.356568694114685}, {'val_loss': 1.3579165935516357}, {'val_loss': 1.3558008670806885}, {'val_loss': 1.3559842109680176}, {'val_loss': 1.3560093641281128}, {'val_loss': 1.355743169784546}, {'val_loss': 1.3560271263122559}, {'val_loss': 1.3596975803375244}, {'val_loss': 1.3558939695358276}, {'val_loss': 1.3556714057922363}, {'val_loss': 1.3555657863616943}, {'val_loss': 1.3557069301605225}, 
{'val_loss': 1.3573758602142334}, {'val_loss': 1.3586246967315674}, {'val_loss': 1.3611829280853271}, {'val_loss': 1.3566774129867554}, {'val_loss': 1.3559454679489136}, {'val_loss': 1.357486367225647}, {'val_loss': 1.3554534912109375}, {'val_loss': 1.3553087711334229}, {'val_loss': 1.3563026189804077}, {'val_loss': 1.3549476861953735}, {'val_loss': 1.3553135395050049}, {'val_loss': 1.3554954528808594}, {'val_loss': 1.3618000745773315}, {'val_loss': 1.3565247058868408}, {'val_loss': 1.3572131395339966}, {'val_loss': 1.3555357456207275}, {'val_loss': 1.3553812503814697}, {'val_loss': 1.3570847511291504}, {'val_loss': 1.3589755296707153}, {'val_loss': 1.3560833930969238}, {'val_loss': 1.3556828498840332}, {'val_loss': 1.3555214405059814}, {'val_loss': 1.3567246198654175}, {'val_loss': 1.3553409576416016}, {'val_loss': 1.3552978038787842}, {'val_loss': 1.355290174484253}, {'val_loss': 1.3553277254104614}, {'val_loss': 1.355329155921936}, {'val_loss': 1.3553335666656494}, {'val_loss': 1.3554489612579346}, {'val_loss': 1.3553740978240967}, {'val_loss': 1.3554728031158447}, {'val_loss': 1.3555446863174438}, {'val_loss': 1.355400800704956}, {'val_loss': 1.35552179813385}, {'val_loss': 1.3552525043487549}, {'val_loss': 1.3552584648132324}, {'val_loss': 1.3552162647247314}, {'val_loss': 1.3552507162094116}, {'val_loss': 1.35527765750885}, {'val_loss': 1.3552875518798828}, {'val_loss': 1.3552402257919312}, {'val_loss': 1.355261206626892}, {'val_loss': 1.3551791906356812}, {'val_loss': 1.3551557064056396}, {'val_loss': 1.355139970779419}, {'val_loss': 1.355112075805664}, {'val_loss': 1.355059266090393}, {'val_loss': 1.355006456375122}, {'val_loss': 1.3549349308013916}, {'val_loss': 1.3549612760543823}, {'val_loss': 1.354889154434204}, {'val_loss': 1.3548805713653564}, {'val_loss': 1.3548870086669922}, {'val_loss': 1.3550840616226196}, {'val_loss': 1.355032205581665}, {'val_loss': 1.355126976966858}, {'val_loss': 1.3553258180618286}, {'val_loss': 1.35527503490448}, 
{'val_loss': 1.3548399209976196}, {'val_loss': 1.354857087135315}, {'val_loss': 1.3548732995986938}, {'val_loss': 1.3548439741134644}, {'val_loss': 1.3548471927642822}, {'val_loss': 1.354846477508545}, {'val_loss': 1.3549150228500366}, {'val_loss': 1.3548537492752075}, {'val_loss': 1.3549411296844482}, {'val_loss': 1.3550673723220825}, {'val_loss': 1.3551135063171387}, {'val_loss': 1.3551299571990967}, {'val_loss': 1.3551310300827026}, {'val_loss': 1.355109453201294}, {'val_loss': 1.3552041053771973}, {'val_loss': 1.3551609516143799}, {'val_loss': 1.3550647497177124}, {'val_loss': 1.3550242185592651}, {'val_loss': 1.3549954891204834}, {'val_loss': 1.354806661605835}, {'val_loss': 1.3548814058303833}, {'val_loss': 1.3550723791122437}, {'val_loss': 1.355039119720459}, {'val_loss': 1.3549984693527222}, {'val_loss': 1.3550713062286377}, {'val_loss': 1.3548802137374878}, {'val_loss': 1.3550053834915161}, {'val_loss': 1.355177640914917}, {'val_loss': 1.3551442623138428}, {'val_loss': 1.3550387620925903}, {'val_loss': 1.354960322380066}, {'val_loss': 1.3548640012741089}, {'val_loss': 1.3548303842544556}, {'val_loss': 1.3547861576080322}, {'val_loss': 1.354933500289917}, {'val_loss': 1.3548729419708252}, {'val_loss': 1.3549764156341553}, {'val_loss': 1.3550817966461182}, {'val_loss': 1.3551771640777588}, {'val_loss': 1.3549935817718506}, {'val_loss': 1.354945421218872}, {'val_loss': 1.3550140857696533}, {'val_loss': 1.355054259300232}, {'val_loss': 1.3547836542129517}, {'val_loss': 1.3548635244369507}, {'val_loss': 1.3549376726150513}, {'val_loss': 1.3549668788909912}, {'val_loss': 1.3548489809036255}, {'val_loss': 1.354942798614502}, {'val_loss': 1.3550866842269897}, {'val_loss': 1.3550900220870972}, {'val_loss': 1.3549708127975464}, {'val_loss': 1.3551058769226074}, {'val_loss': 1.3551361560821533}, {'val_loss': 1.355053424835205}, {'val_loss': 1.354856014251709}, {'val_loss': 1.3548083305358887}, {'val_loss': 1.3548500537872314}, {'val_loss': 1.354717493057251}, 
{'val_loss': 1.3547502756118774}, {'val_loss': 1.3547343015670776}, {'val_loss': 1.3547616004943848}, {'val_loss': 1.3547896146774292}, {'val_loss': 1.3548692464828491}, {'val_loss': 1.3549540042877197}, {'val_loss': 1.3549708127975464}, {'val_loss': 1.3549656867980957}, {'val_loss': 1.3549615144729614}, {'val_loss': 1.35495924949646}, {'val_loss': 1.3549537658691406}, {'val_loss': 1.3549537658691406}, {'val_loss': 1.354970097541809}, {'val_loss': 1.3549647331237793}, {'val_loss': 1.3549704551696777}, {'val_loss': 1.3549609184265137}, {'val_loss': 1.3549612760543823}, {'val_loss': 1.3549517393112183}, {'val_loss': 1.3549433946609497}, {'val_loss': 1.3549416065216064}, {'val_loss': 1.354941964149475}, {'val_loss': 1.354957938194275}, {'val_loss': 1.3549760580062866}, {'val_loss': 1.3549773693084717}, {'val_loss': 1.3549883365631104}, {'val_loss': 1.3549954891204834}, {'val_loss': 1.3550008535385132}, {'val_loss': 1.3549926280975342}, {'val_loss': 1.3549751043319702}, {'val_loss': 1.3549761772155762}, {'val_loss': 1.3549797534942627}, {'val_loss': 1.3549789190292358}, {'val_loss': 1.3549652099609375}, {'val_loss': 1.3549779653549194}, {'val_loss': 1.3549611568450928}, {'val_loss': 1.3549631834030151}, {'val_loss': 1.3549692630767822}, {'val_loss': 1.354960560798645}, {'val_loss': 1.3549777269363403}, {'val_loss': 1.3549540042877197}, {'val_loss': 1.3549739122390747}, {'val_loss': 1.3549824953079224}, {'val_loss': 1.354980707168579}, {'val_loss': 1.3549729585647583}, {'val_loss': 1.3549515008926392}, {'val_loss': 1.354953408241272}, {'val_loss': 1.3549593687057495}, {'val_loss': 1.354970932006836}, {'val_loss': 1.3549659252166748}, {'val_loss': 1.3549597263336182}, {'val_loss': 1.3549662828445435}, {'val_loss': 1.354949951171875}, {'val_loss': 1.3549219369888306}, {'val_loss': 1.354909062385559}, {'val_loss': 1.3549017906188965}, {'val_loss': 1.354910969734192}, {'val_loss': 1.3549164533615112}, {'val_loss': 1.3549116849899292}, {'val_loss': 1.3549190759658813}, 
{'val_loss': 1.354902982711792}, {'val_loss': 1.3549138307571411}, {'val_loss': 1.3549312353134155}, {'val_loss': 1.3549350500106812}, {'val_loss': 1.354937195777893}, {'val_loss': 1.3549233675003052}, {'val_loss': 1.3549250364303589}, {'val_loss': 1.3549437522888184}, {'val_loss': 1.354934573173523}, {'val_loss': 1.3549232482910156}, {'val_loss': 1.3549304008483887}, {'val_loss': 1.354939579963684}, {'val_loss': 1.3549563884735107}, {'val_loss': 1.354954481124878}, {'val_loss': 1.3549622297286987}, {'val_loss': 1.3549681901931763}, {'val_loss': 1.354967713356018}, {'val_loss': 1.3549511432647705}, {'val_loss': 1.3549466133117676}, {'val_loss': 1.3549611568450928}, {'val_loss': 1.3549678325653076}, {'val_loss': 1.3549621105194092}, {'val_loss': 1.3549638986587524}, {'val_loss': 1.3549736738204956}, {'val_loss': 1.3549846410751343}, {'val_loss': 1.3549665212631226}, {'val_loss': 1.3549554347991943}, {'val_loss': 1.3549529314041138}, {'val_loss': 1.35495126247406}, {'val_loss': 1.3549467325210571}, {'val_loss': 1.3549555540084839}, {'val_loss': 1.3549621105194092}, {'val_loss': 1.354956030845642}, {'val_loss': 1.3549562692642212}, {'val_loss': 1.3549613952636719}, {'val_loss': 1.3549690246582031}, {'val_loss': 1.3550021648406982}, {'val_loss': 1.354974389076233}, {'val_loss': 1.3549872636795044}, {'val_loss': 1.354970097541809}, {'val_loss': 1.3549933433532715}, {'val_loss': 1.355014443397522}, {'val_loss': 1.3550028800964355}, {'val_loss': 1.3550028800964355}, {'val_loss': 1.3550089597702026}, {'val_loss': 1.3550013303756714}, {'val_loss': 1.3550052642822266}, {'val_loss': 1.3550039529800415}, {'val_loss': 1.3550015687942505}, {'val_loss': 1.3550002574920654}, {'val_loss': 1.3549972772598267}, {'val_loss': 1.3549988269805908}, {'val_loss': 1.3549987077713013}, {'val_loss': 1.3549968004226685}, {'val_loss': 1.354996919631958}, {'val_loss': 1.354996919631958}, {'val_loss': 1.3549970388412476}, {'val_loss': 1.3549975156784058}, {'val_loss': 1.3549970388412476}, 
{'val_loss': 1.354995846748352}, {'val_loss': 1.3549954891204834}, {'val_loss': 1.3549965620040894}, {'val_loss': 1.3549964427947998}, {'val_loss': 1.3549935817718506}, {'val_loss': 1.3549926280975342}, {'val_loss': 1.3549935817718506}, {'val_loss': 1.354992389678955}, {'val_loss': 1.354993224143982}, {'val_loss': 1.3549939393997192}, {'val_loss': 1.354995846748352}, {'val_loss': 1.3549952507019043}, {'val_loss': 1.3549952507019043}, {'val_loss': 1.3549950122833252}, {'val_loss': 1.3549972772598267}, {'val_loss': 1.3549987077713013}, {'val_loss': 1.3549998998641968}, {'val_loss': 1.3549989461898804}, {'val_loss': 1.3550000190734863}, {'val_loss': 1.3549987077713013}, {'val_loss': 1.3549976348876953}, {'val_loss': 1.3549994230270386}, {'val_loss': 1.3549989461898804}, {'val_loss': 1.3549981117248535}, {'val_loss': 1.3549984693527222}, {'val_loss': 1.3549996614456177}, {'val_loss': 1.35499906539917}, {'val_loss': 1.3549968004226685}, {'val_loss': 1.3549973964691162}, {'val_loss': 1.3549984693527222}, {'val_loss': 1.354998230934143}, {'val_loss': 1.3549988269805908}, {'val_loss': 1.3549988269805908}, {'val_loss': 1.3549960851669312}, {'val_loss': 1.3549968004226685}, {'val_loss': 1.354997992515564}, {'val_loss': 1.3549962043762207}, {'val_loss': 1.3549968004226685}, {'val_loss': 1.354996919631958}, {'val_loss': 1.3549964427947998}, {'val_loss': 1.3549950122833252}, {'val_loss': 1.3549957275390625}, {'val_loss': 1.3549973964691162}, {'val_loss': 1.3549975156784058}, {'val_loss': 1.3549976348876953}, {'val_loss': 1.354999303817749}, {'val_loss': 1.3550000190734863}, {'val_loss': 1.3549994230270386}, {'val_loss': 1.3549977540969849}, {'val_loss': 1.3549987077713013}, {'val_loss': 1.3550001382827759}, {'val_loss': 1.3550000190734863}, {'val_loss': 1.3549977540969849}, {'val_loss': 1.3549970388412476}, {'val_loss': 1.3549978733062744}, {'val_loss': 1.3549968004226685}, {'val_loss': 1.354997158050537}, {'val_loss': 1.3549964427947998}, {'val_loss': 1.3549972772598267}, 
{'val_loss': 1.3549972772598267}, {'val_loss': 1.35499906539917}, {'val_loss': 1.3549988269805908}, {'val_loss': 1.3549983501434326}, {'val_loss': 1.35499906539917}, {'val_loss': 1.3549996614456177}, {'val_loss': 1.35499906539917}, {'val_loss': 1.3549988269805908}, {'val_loss': 1.3550015687942505}, {'val_loss': 1.355000376701355}, {'val_loss': 1.3549975156784058}, {'val_loss': 1.35499906539917}, {'val_loss': 1.3549977540969849}, {'val_loss': 1.3549978733062744}, {'val_loss': 1.3549991846084595}, {'val_loss': 1.3549988269805908}, {'val_loss': 1.3549984693527222}, {'val_loss': 1.3549978733062744}, {'val_loss': 1.3549984693527222}, {'val_loss': 1.3549995422363281}, {'val_loss': 1.354999303817749}, {'val_loss': 1.3549972772598267}, {'val_loss': 1.354998230934143}, {'val_loss': 1.354998230934143}, {'val_loss': 1.3549978733062744}, {'val_loss': 1.354997992515564}, {'val_loss': 1.3549975156784058}, {'val_loss': 1.35499906539917}, {'val_loss': 1.3549996614456177}]
Step 5: Make predictions using the trained model
In [0]:
```python
def predict_single(input, target, model):
    inputs = input.unsqueeze(0)  # Add a batch dimension: shape [6] -> [1, 6]
    predictions = model(inputs)
    prediction = predictions[0].detach()
    print("Input:", input)
    print("Target:", target)
    print("Prediction:", prediction)
```
In [30]:
```python
input, target = val_ds[0]
predict_single(input, target, model)
```
```
Input: tensor([12.0000, 8.2500, 1.0000, 0.0000, 1.0000, 0.0000])
Target: tensor([9.4000])
Prediction: tensor([10.9955])
```
In [31]:
```python
input, target = val_ds[10]
predict_single(input, target, model)
```
```
Input: tensor([12.0000, 5.2500, 2.0000, 0.0000, 1.0000, 0.0000])
Target: tensor([5.9000])
Prediction: tensor([7.5854])
```
In [32]:
```python
input, target = val_ds[23]
predict_single(input, target, model)
```
```
Input: tensor([11.0000, 4.9000, 1.0000, 0.0000, 1.0000, 0.0000])
Target: tensor([8.9300])
Prediction: tensor([6.5628])
```
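To put the three predictions in context, the snippet below computes their absolute errors from the printed targets and predictions. These are the same per-sample quantities that `F.l1_loss` averages. Three samples are far too few to estimate the model's accuracy, but their mean error is of the same order as the final validation loss of about 1.355.

```python
# (target, prediction) pairs copied from the three outputs above
pairs = [(9.40, 10.9955), (5.90, 7.5854), (8.93, 6.5628)]

errors = [abs(prediction - target) for target, prediction in pairs]
print([round(e, 4) for e in errors])        # [1.5955, 1.6854, 2.3672]
print(round(sum(errors) / len(errors), 4))  # 1.8827
```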