Keras Tutorial - How to Use Google's Universal Sentence Encoder for Spam Classification

A tutorial on embedding Google's Universal Sentence Encoder (USE) in your Keras models

Cover image

The Tutorial Video

If you enjoyed this video or found it helpful in any way, I would love you forever if you passed along a dollar or two to help fund my machine learning education and research! Every dollar gets me a little closer to my goals, and I’m forever grateful.

The Code (written against the TensorFlow 1.x API: it uses `tf.Session` and `hub.Module`)

import tensorflow as tf
import tensorflow_hub as hub
import pandas as pd
from sklearn import preprocessing
import keras
import numpy as np

# TF Hub handle for the TF1-format Universal Sentence Encoder. The Lambda
# layer of the model declares a 512-dimensional output for this module.
url = "https://tfhub.dev/google/universal-sentence-encoder-large/3"
embed = hub.Module(url)

# Presumably the SMS spam CSV: column 'v1' holds the label ('ham'/'spam'),
# column 'v2' the raw message text. latin-1 because the file is not UTF-8
# (assumption based on the explicit encoding; confirm against the data).
data = pd.read_csv('spam.csv', encoding='latin-1')

y = list(data['v1'])  # string labels
x = list(data['v2'])  # raw message strings, embedded later by USE

# The encoder must be fitted on the label vocabulary before encode()/decode()
# can call transform()/inverse_transform(); an unfitted LabelEncoder raises
# NotFittedError on the first transform.
le = preprocessing.LabelEncoder()
le.fit(y)

def encode(le, labels):
    """One-hot encode string labels using a fitted LabelEncoder.

    Args:
        le: a fitted ``sklearn.preprocessing.LabelEncoder``.
        labels: iterable of label strings that ``le`` was fitted on.

    Returns:
        One-hot array of shape ``(len(labels), len(le.classes_))``.
    """
    enc = le.transform(labels)
    # Pin the width to the encoder's class count. Without num_classes,
    # to_categorical infers it as max(enc) + 1 and returns too few columns
    # whenever the highest-indexed class is absent from ``labels``.
    return keras.utils.to_categorical(enc, num_classes=len(le.classes_))

def decode(le, one_hot):
    """Turn rows of one-hot vectors or class probabilities back into labels.

    Args:
        le: a fitted ``sklearn.preprocessing.LabelEncoder``.
        one_hot: 2-D array-like; row ``i`` scores every class for sample ``i``.

    Returns:
        For each row, the label of its highest-scoring column, as given by
        ``le.inverse_transform``.
    """
    return le.inverse_transform(np.argmax(one_hot, axis=1))

# Sanity check: decoding the one-hot encoding of some labels must give the
# same labels back; fail loudly instead of silently discarding the result.
sample_labels = ['ham', 'spam', 'ham', 'ham']
test = encode(le, sample_labels)
untest = decode(le, test)
if list(untest) != sample_labels:
    raise RuntimeError(f"label round trip failed: {list(untest)!r}")

x_enc = x  # raw strings; the USE Lambda layer in the model embeds them
y_enc = encode(le, y)

# Positional split (no shuffling): the first TRAIN_SIZE rows train the model,
# the remaining rows are held out, so the split inherits the CSV row order.
TRAIN_SIZE = 5000

x_train = np.asarray(x_enc[:TRAIN_SIZE])
y_train = np.asarray(y_enc[:TRAIN_SIZE])

x_test = np.asarray(x_enc[TRAIN_SIZE:])
y_test = np.asarray(y_enc[TRAIN_SIZE:])

from keras.layers import Input, Lambda, Dense
from keras.models import Model
import keras.backend as K

def UniversalEmbedding(x):
    """Embed a batch of sentences with the Universal Sentence Encoder.

    Args:
        x: tensor of shape ``(batch, 1)`` with one sentence per row, as fed
            by an ``Input(shape=(1,), dtype=tf.string)`` layer.

    Returns:
        The hub module's embedding of the 1-D batch of strings.
    """
    # Squeeze only the feature axis. A bare tf.squeeze() would also drop the
    # batch axis when the batch holds a single sentence, passing the module a
    # scalar instead of the 1-D string batch it expects.
    return embed(tf.squeeze(tf.cast(x, tf.string), axis=1))

# Model: raw string -> USE embedding (512-d) -> Dense(256, relu) -> softmax.
input_text = Input(shape=(1,), dtype=tf.string)
embedding = Lambda(UniversalEmbedding, output_shape=(512, ))(input_text)
dense = Dense(256, activation='relu')(embedding)
# One output unit per one-hot label column (2 for ham/spam).
pred = Dense(y_train.shape[1], activation='softmax')(dense)
model = Model(inputs=[input_text], outputs=pred)
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

# Train and predict inside ONE session. Keras has to be told to use this
# session, and the hub module's variables and lookup tables must be
# initialized before fit(). A second, fresh session would start from
# uninitialized variables and lose the trained weights.
with tf.Session() as session:
    K.set_session(session)
    session.run([tf.global_variables_initializer(), tf.tables_initializer()])
    history = model.fit(x_train, y_train, epochs=1, batch_size=32)
    predicts = model.predict(x_test, batch_size=32)

# Convert one-hot targets and softmax outputs back into label strings.
y_test = decode(le, y_test)
y_preds = decode(le, predicts)

from sklearn import metrics

# Rows are true labels, columns predicted labels (sklearn convention).
# Printed explicitly: a bare expression is only displayed in a notebook and
# is silently discarded when this runs as a script.
print(metrics.confusion_matrix(y_test, y_preds))

print(metrics.classification_report(y_test, y_preds))