Training
To train your own harmonized model, we provide a simple way to load the ClickMe training set, along with the harmonization loss used in the paper.
Loading the ClickMe training set
First, you will need to load the training dataset:
from harmonization.common import load_clickme_train
clickme_ds = load_clickme_train(batch_size=128)
for images, heatmaps, labels in clickme_ds:
    print(images.shape)    # (128, 224, 224, 3)
    print(heatmaps.shape)  # (128, 224, 224, 1)
    print(labels.shape)    # (128, 1000)
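To smoke-test a training loop without downloading ClickMe, it can help to have a stand-in that yields batches with the same shapes as the comments above. The sketch below is hypothetical: `fake_clickme_batches` is not part of `harmonization`. It uses random NumPy arrays and one-hot labels and makes no attempt to mimic the real value ranges or preprocessing.

```python
import numpy as np


def fake_clickme_batches(num_batches, batch_size=128, image_size=224,
                         num_classes=1000, seed=0):
    """Yield (images, heatmaps, labels) shaped like the ClickMe batches above.

    Hypothetical stand-in: values are random and do not reflect the real
    data or its preprocessing.
    """
    rng = np.random.default_rng(seed)
    for _ in range(num_batches):
        images = rng.random((batch_size, image_size, image_size, 3), dtype=np.float32)
        heatmaps = rng.random((batch_size, image_size, image_size, 1), dtype=np.float32)
        # One-hot labels, matching the (batch_size, 1000) shape of the real labels
        class_ids = rng.integers(0, num_classes, size=batch_size)
        labels = np.eye(num_classes, dtype=np.float32)[class_ids]
        yield images, heatmaps, labels


for images, heatmaps, labels in fake_clickme_batches(num_batches=1, batch_size=4):
    print(images.shape)    # (4, 224, 224, 3)
    print(heatmaps.shape)  # (4, 224, 224, 1)
    print(labels.shape)    # (4, 1000)
```

A small `batch_size` keeps memory low. A single full-size batch of 128 float32 images at 224×224×3 already takes about 77 MB.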
Note that if you already have the shards locally, you can also load the dataset with the load_clickme function:
from harmonization.common import load_clickme
clickme_ds = load_clickme(shards_paths=['dataset/train_clickme_0',
                                        'dataset/train_clickme_1',
                                        # ...
                                        ], batch_size=128)
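Instead of typing every shard path, you can collect the paths from disk. The helper below is a hypothetical sketch. It assumes the shards live in one directory and are named `train_clickme_<index>` with no file extension, as in the example paths above. It sorts the shards by their numeric suffix, so `train_clickme_10` comes after `train_clickme_2`; a plain lexicographic sort would not give that order.

```python
import os
import re


def list_clickme_shards(directory, prefix="train_clickme_"):
    """Return paths in `directory` named `<prefix><index>`, sorted by index.

    Hypothetical helper: the naming scheme is assumed from the example paths.
    """
    pattern = re.compile(re.escape(prefix) + r"(\d+)$")
    shards = []
    for name in os.listdir(directory):
        match = pattern.match(name)
        if match:
            shards.append((int(match.group(1)), os.path.join(directory, name)))
    # Sort numerically on the shard index, not lexicographically on the name
    return [path for _, path in sorted(shards)]
```

The result can then be passed as `shards_paths` to `load_clickme`, for example `load_clickme(shards_paths=list_clickme_shards('dataset'), batch_size=128)`.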
Using the harmonization loss
Now that the training set is loaded, the only other piece you need is the harmonization loss. Its signature is:
import tensorflow as tf
def harmonizer_loss(model, images, tokens, labels, true_heatmaps,
                    cross_entropy=tf.keras.losses.CategoricalCrossentropy(),
                    lambda_weights=1e-5, lambda_harmonization=1.0):
    ...
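The parameter names suggest that the loss combines three terms:

- the cross-entropy on the labels;
- a harmonization term that compares the model's saliency maps with the ClickMe heatmaps, scaled by `lambda_harmonization`;
- a weight penalty, scaled by `lambda_weights`.

The NumPy sketch below illustrates only that weighted sum, under these assumptions. It is not the library's implementation. The harmonization term is passed in precomputed, and the `0.5 * sum(w**2)` L2 penalty is an assumed convention.

```python
import numpy as np


def combined_objective(probs, labels, harmonization_term, weights,
                       lambda_weights=1e-5, lambda_harmonization=1.0, eps=1e-7):
    """Hypothetical sketch: CE + lambda_harmonization * H + lambda_weights * L2.

    probs, labels: arrays of shape (batch, num_classes); labels are one-hot.
    harmonization_term: precomputed scalar comparing saliency maps and heatmaps.
    weights: iterable of weight arrays to regularize.
    """
    # Mean categorical cross-entropy between one-hot labels and probabilities
    cross_entropy = -np.mean(np.sum(labels * np.log(np.clip(probs, eps, 1.0)), axis=-1))
    # L2 penalty over all weight arrays (0.5 * squared norm, an assumed convention)
    weight_penalty = sum(0.5 * np.sum(w ** 2) for w in weights)
    return (cross_entropy
            + lambda_harmonization * harmonization_term
            + lambda_weights * weight_penalty)
```

With the defaults, the harmonization term is weighted as strongly as the classification loss, while the weight penalty stays small.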
To use the loss, call the function with your model, the ClickMe images and labels, and the corresponding heatmaps:
import tensorflow as tf
from harmonization.training import harmonizer_loss
# ... load the dataset as shown above
for images, heatmaps, labels in clickme_ds:
    # tokens flag whether each image has an associated heatmap (1 = heatmap provided)
    tokens = tf.ones(len(images))
    loss = harmonizer_loss(model, images, tokens, labels, heatmaps)
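Conceptually, the `tokens` vector acts as a per-sample mask on the heatmap term: samples flagged with 0 should not contribute to it. The function below is a hypothetical NumPy sketch of such a masked comparison, not the library's code. Two choices in it are assumptions: it uses a plain mean squared error, and it averages over flagged samples only. The actual loss may compare and normalize the maps differently.

```python
import numpy as np


def masked_heatmap_mse(saliency, heatmaps, tokens):
    """Hypothetical sketch: MSE between maps, counting only samples with token 1.

    saliency, heatmaps: arrays of shape (batch, height, width, 1)
    tokens: array of shape (batch,), 1 if a heatmap is provided, else 0
    """
    tokens = np.asarray(tokens, dtype=np.float32)
    # Per-sample squared error, averaged over the spatial and channel axes
    per_sample = np.mean((saliency - heatmaps) ** 2, axis=(1, 2, 3))
    # Zero out samples without a heatmap, then average over the flagged ones
    num_flagged = max(float(tokens.sum()), 1.0)
    return float(np.sum(per_sample * tokens) / num_flagged)
```

Guarding the denominator with `max(..., 1.0)` keeps a batch with no heatmaps at all from dividing by zero. Such a batch simply contributes 0.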
For example, if you mix the ClickMe dataset with ImageNet, not every image will have a heatmap. In that case, use the tokens parameter to indicate which images come with a heatmap (1 means a heatmap is provided, 0 means it is not).
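Here is a minimal sketch of such a mixed batch. It assumes ImageNet samples arrive as images with one-hot labels but no heatmap. The hypothetical `mix_batches` helper does three things:

- concatenates a ClickMe batch with an ImageNet batch;
- fills the missing heatmaps with zeros, as placeholders that the zero tokens are meant to exclude;
- builds the matching `tokens` vector.

```python
import numpy as np


def mix_batches(clickme_images, clickme_heatmaps, clickme_labels,
                imagenet_images, imagenet_labels):
    """Hypothetical helper: merge ClickMe and ImageNet samples into one batch.

    Returns (images, heatmaps, labels, tokens) where tokens[i] == 1 only for
    samples that come with a real ClickMe heatmap.
    """
    n_clickme = len(clickme_images)
    n_imagenet = len(imagenet_images)
    # Placeholder heatmaps for ImageNet samples; tokens == 0 marks them as absent
    empty_heatmaps = np.zeros((n_imagenet,) + clickme_heatmaps.shape[1:],
                              dtype=clickme_heatmaps.dtype)
    images = np.concatenate([clickme_images, imagenet_images], axis=0)
    heatmaps = np.concatenate([clickme_heatmaps, empty_heatmaps], axis=0)
    labels = np.concatenate([clickme_labels, imagenet_labels], axis=0)
    tokens = np.concatenate([np.ones(n_clickme, dtype=np.float32),
                             np.zeros(n_imagenet, dtype=np.float32)])
    return images, heatmaps, labels, tokens
```

The four outputs line up with the arguments of `harmonizer_loss(model, images, tokens, labels, heatmaps)`.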