Federated Learning - Learning Across Devices
Federated learning is a machine learning approach in which many devices collaborate to train a shared model without exchanging their raw data, which helps preserve user privacy. Each device trains the current model locally on its own data and sends the resulting model weights (or weight updates) to a central server. The server aggregates these updates into a new global model and sends it back to the devices, which continue training it on their own data. This cycle repeats for many rounds.
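The train-locally, aggregate, redistribute loop can be illustrated without any framework. The sketch below assumes one common aggregation rule, Federated Averaging (FedAvg), where the server averages client models weighted by how many training examples each client holds. The toy model (a single weight `w` fitted to `y ≈ w * x`), the client data, and the helper names `local_train` and `fedavg_round` are hypothetical and chosen only for illustration.

```python
def local_train(w, data, lr=0.1, epochs=5):
    """Client side: a few epochs of full-batch gradient descent on squared loss for y ≈ w * x."""
    for _ in range(epochs):
        grad = sum(2 * (w * x - y) * x for x, y in data) / len(data)
        w -= lr * grad
    return w

def fedavg_round(global_w, client_datasets):
    """One round: every client trains from the global weight, the server takes a weighted mean."""
    total = sum(len(d) for d in client_datasets)
    return sum(local_train(global_w, d) * len(d) for d in client_datasets) / total

# Hypothetical private datasets held by three clients; all follow y = 3x
clients = [
    [(1.0, 3.0), (2.0, 6.0)],
    [(0.5, 1.5)],
    [(1.5, 4.5), (3.0, 9.0), (0.2, 0.6)],
]

w = 0.0  # initial global model
for _ in range(20):
    w = fedavg_round(w, clients)  # only weights travel; the (x, y) pairs never leave a client
print(round(w, 3))  # prints 3.0
```

Weighting by example count means a client with three examples influences the global model three times as much as a client with one, which matches what a single model trained on the pooled data would see.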
Here is an example of how to implement federated learning using TensorFlow Federated in Python:
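```python
import collections

import tensorflow as tf
import tensorflow_federated as tff

NUM_CLIENTS = 10  # number of simulated clients used for training and evaluation
BATCH_SIZE = 20
NUM_ROUNDS = 10

# Load federated EMNIST digits: ClientData objects already partitioned by writer
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

# Preprocess one client's data: flatten 28x28 images to 784 values, one-hot encode labels
def preprocess(dataset):
    def batch_format_fn(element):
        return collections.OrderedDict(
            x=tf.reshape(element['pixels'], [-1, 784]),
            y=tf.one_hot(element['label'], 10))
    return dataset.shuffle(100, seed=1).batch(BATCH_SIZE).map(batch_format_fn)

# Build a list with one preprocessed tf.data.Dataset per client
def make_federated_data(client_data, client_ids):
    return [preprocess(client_data.create_tf_dataset_for_client(cid))
            for cid in client_ids]

federated_train_data = make_federated_data(emnist_train, emnist_train.client_ids[:NUM_CLIENTS])
federated_test_data = make_federated_data(emnist_test, emnist_test.client_ids[:NUM_CLIENTS])

# Define a function that creates a Keras model
def create_keras_model():
    return tf.keras.models.Sequential([
        tf.keras.layers.Dense(10, input_shape=(784,), activation='softmax')
    ])

# Define the Federated model; TFF calls this repeatedly, so it must build a new Keras model each time
def create_federated_model():
    return tff.learning.models.from_keras_model(
        create_keras_model(),
        input_spec=federated_train_data[0].element_spec,
        loss=tf.keras.losses.CategoricalCrossentropy(),
        metrics=[tf.keras.metrics.CategoricalAccuracy()])

# Define the Federated Averaging process (local SGD on clients, weighted averaging on the server)
training_process = tff.learning.algorithms.build_weighted_fed_avg(
    create_federated_model,
    client_optimizer_fn=tff.learning.optimizers.build_sgdm(learning_rate=0.02),
    server_optimizer_fn=tff.learning.optimizers.build_sgdm(learning_rate=1.0))

# Train the Federated model
state = training_process.initialize()
for i in range(NUM_ROUNDS):
    result = training_process.next(state, federated_train_data)
    state = result.state
    print(f'Round {i+1}, metrics={result.metrics}')

# Evaluate the trained global model on the clients' test data
evaluation_process = tff.learning.algorithms.build_fed_eval(create_federated_model)
eval_state = evaluation_process.initialize()
eval_state = evaluation_process.set_model_weights(
    eval_state, training_process.get_model_weights(state))
eval_output = evaluation_process.next(eval_state, federated_test_data)
print(f'Test metrics={eval_output.metrics}')
```

In this code example, we train a federated model on EMNIST, the federated version of the MNIST handwritten-digit dataset that TensorFlow Federated provides through `tff.simulation.datasets.emnist.load_data`. The dataset comes already partitioned by writer, and each writer plays the role of one client. Because the training and test sets are returned as `tff.simulation.datasets.ClientData` objects rather than plain arrays, we select the first few clients and convert each client's data into a batched `tf.data.Dataset` with flattened 784-pixel images and one-hot labels. The list of these per-client datasets is what the federated computations consume.

We define a function `create_keras_model` that creates a simple Keras model with a single dense softmax layer, and a function `create_federated_model` that wraps a new copy of it with `tff.learning.models.from_keras_model`. We then build the Federated Averaging process with `tff.learning.algorithms.build_weighted_fed_avg`, train the model for several rounds in a loop, and evaluate the resulting global model on the clients' test data using `tff.learning.algorithms.build_fed_eval`. Older TFF releases provided the same functionality as `tff.learning.build_federated_averaging_process` and `tff.learning.build_federated_evaluation`.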
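A subtle point in federated evaluation is how per-client metrics are combined. As we understand it, TFF's default for Keras metrics is to sum each metric's internal counters across clients and only then compute the final value (its `sum_then_finalize` aggregator), which makes accuracy example-weighted. The framework-free sketch below contrasts that with naively averaging per-client accuracies; the function names and the `(correct, count)` numbers are hypothetical.

```python
def accuracy_mean_of_clients(client_stats):
    """Naive: average each client's accuracy, ignoring how many examples it evaluated."""
    return sum(correct / count for correct, count in client_stats) / len(client_stats)

def accuracy_sum_then_finalize(client_stats):
    """Sum the counters across clients first, then compute a single global accuracy."""
    total_correct = sum(correct for correct, _ in client_stats)
    total_count = sum(count for _, count in client_stats)
    return total_correct / total_count

# Hypothetical (correct, count) pairs reported by three clients
stats = [(9, 10), (50, 100), (1, 1)]
print(accuracy_mean_of_clients(stats))
print(accuracy_sum_then_finalize(stats))
```

The naive mean is about 0.80 because the one-example client counts as much as the hundred-example client, while summing first gives 60/111 ≈ 0.54, the accuracy a centralized evaluation over all 111 examples would report.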
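This example is a simplified simulation of federated learning: all clients run in a single process, and no protections are applied to the model updates. In practice, a federated learning system needs additional security and privacy safeguards, because model updates can still leak information about the data they were computed from. Common measures include secure aggregation, so that the server only learns the sum of the client updates, and differential privacy.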
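To make the secure-aggregation idea concrete, here is a heavily simplified sketch of pairwise masking: every pair of clients shares a random mask, the lower-numbered client adds it and the higher-numbered one subtracts it, so each masked update looks random on its own while the masks cancel in the server's sum. It assumes updates are already quantized to integers and summed modulo `2**32`, and the seed derivation in `pairwise_mask` is a hypothetical stand-in for a real key agreement; it ignores client dropouts and is not a secure protocol.

```python
import random

MODULUS = 2**32  # quantized updates are summed modulo 2**32

def pairwise_mask(i, j, dim, seed_base=1234):
    """Mask shared by clients i < j (hypothetical: derived from a shared seed)."""
    rng = random.Random(seed_base * 1_000_003 + i * 1_009 + j)
    return [rng.randrange(MODULUS) for _ in range(dim)]

def mask_update(client_id, update, all_ids):
    """Add masks shared with higher ids and subtract masks shared with lower ids."""
    masked = list(update)
    for other in all_ids:
        if other == client_id:
            continue
        lo, hi = min(client_id, other), max(client_id, other)
        mask = pairwise_mask(lo, hi, len(update))
        sign = 1 if client_id == lo else -1
        masked = [(m + sign * k) % MODULUS for m, k in zip(masked, mask)]
    return masked

def server_sum(masked_updates):
    """The server only adds masked vectors; every pairwise mask cancels out."""
    dim = len(masked_updates[0])
    return [sum(u[d] for u in masked_updates) % MODULUS for d in range(dim)]

# Hypothetical quantized updates from three clients
updates = {0: [5, 10, 7], 1: [1, 2, 3], 2: [100, 0, 42]}
ids = list(updates)
masked = [mask_update(cid, upd, ids) for cid, upd in updates.items()]
print(server_sum(masked))  # [106, 12, 52]
```

The server recovers exactly the element-wise sum of the three updates, which is all that averaging needs, without ever seeing an individual client's unmasked vector.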