3.3.4 Deep Learning Frameworks

Deep learning frameworks are software libraries that provide the building blocks for creating and training neural networks. They abstract away low-level details (such as tensor operations and gradient computation) so developers can focus on model design and experimentation. Figure 1 illustrates a simple feedforward neural network with input, hidden, and output layers, the fundamental structure behind many deep learning models. Frameworks represent such models as computation graphs of tensor operations and automatically handle gradient backpropagation. For example, TensorFlow (before version 2) builds a static graph at model-definition time, whereas PyTorch uses a dynamic graph evaluated on the fly. This distinction between static and dynamic computation graphs affects a framework's flexibility and performance.

Figure 1: A simple feedforward neural network (ANN) with input, hidden, and output layers. Deep learning frameworks represent such architectures as computation graphs of tensor operations.
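
To make the static-versus-dynamic distinction concrete, here is a minimal sketch (assuming TensorFlow 2.x): the same function runs eagerly by default, and tf.function traces it into a static graph that TensorFlow can then optimize and reuse.

import tensorflow as tf

def mse(pred, target):
    return tf.reduce_mean((pred - target) ** 2)

# Eager execution (the default): each operation runs immediately, as in PyTorch
print(mse(tf.constant([1.0, 2.0]), tf.constant([1.0, 0.0])))

# tf.function traces the Python function into a static computation graph,
# which TensorFlow can optimize (e.g. with XLA) and reuse across calls
graph_mse = tf.function(mse)
print(graph_mse(tf.constant([1.0, 2.0]), tf.constant([1.0, 0.0])))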

Deep learning frameworks also manage components like optimizers (e.g. SGD, Adam), loss functions (e.g. cross-entropy), and data pipelines. They leverage automatic differentiation (autograd) to compute gradients through complex networks. In practice, developers define a model, specify a loss, and use the framework’s training API to iterate over data, backpropagate errors, and update weights. For example, a generic training loop in pseudocode is:

for epoch in range(num_epochs):
    for X_batch, Y_batch in data:
        optimizer.zero_grad()               # reset gradients (frameworks like PyTorch accumulate them)
        outputs = model(X_batch)            # forward pass
        loss = loss_function(outputs, Y_batch)
        loss.backward()                     # backward pass (compute gradients)
        optimizer.step()                    # update weights

This training pattern is common to all frameworks, though many now provide high-level tools to abstract away boilerplate. The choice of framework often depends on factors like performance, ease of use, and target deployment (discussed below).

Comparative Feature Analysis

Different frameworks excel in different aspects. Below is a summary comparison along key dimensions:

  • Ease of Use & Flexibility: PyTorch and Keras are widely praised for their intuitive, Pythonic APIs and quick prototyping. PyTorch’s eager execution model makes debugging and dynamic model construction straightforward. Keras (now integrated into TensorFlow) offers a very high-level interface ideal for beginners. TensorFlow originally used a more complex static-graph approach, but with TensorFlow 2.x and Keras integration it has become much more user-friendly.

  • Performance and Scalability: TensorFlow is highly optimized for large-scale applications, especially on distributed systems. It can compile computation graphs via XLA and scale across multiple GPUs or TPUs. JAX also uses XLA compilation to maximize speed on accelerators. PyTorch is competitive in performance, and while it historically focused on research, it now offers production-grade optimizations (TorchScript, JIT) and multi-GPU training. In practice, both TensorFlow and PyTorch achieve similar performance on many tasks, with TensorFlow having a slight edge in large-scale optimization and resource management.

  • Community and Ecosystem: TensorFlow, backed by Google since 2015, has a very large and mature community with extensive documentation, tutorials, and third-party tools. PyTorch, backed by Meta, has rapidly grown especially in academia and research. JAX’s community is smaller but growing, with libraries like Flax and Haiku supporting high-level model building. Keras (as part of TensorFlow) benefits from TensorFlow’s ecosystem, while PyTorch Lightning is an emerging framework that simplifies PyTorch workflows.

  • Extensibility: Most frameworks allow custom layers and operations. PyTorch and JAX, being more low-level, give full control at the cost of more boilerplate. TensorFlow (especially via tf.keras) and PyTorch Lightning provide higher-level abstractions. PyTorch Lightning, for instance, decouples engineering logic from research logic, enabling built-in support for multi-GPU/TPU training and many engineering “tricks” out-of-the-box.

  • Interoperability: ONNX (Open Neural Network Exchange) enables models to be ported between frameworks. It is an open standard that lets you train in one framework (e.g. PyTorch) and deploy in another (e.g. TensorFlow), typically with little or no performance loss. Many frameworks provide built-in export to ONNX and integration with ONNX Runtime for deployment.

TensorFlow (with Keras)

TensorFlow is a comprehensive library from Google for deep learning and machine learning. It ships with TensorBoard for visualizing training and supports complex architectures. Modern TensorFlow uses eager execution by default (similar to PyTorch) but still allows graph optimization via tf.function and XLA. Its high-level API, Keras, is fully integrated, enabling easy model building. TensorFlow excels at production deployment: it includes TensorFlow Serving for scalable model serving, TensorFlow Lite for mobile/edge devices, and TensorFlow.js for in-browser inference.

Example – Building and Training a Model in TensorFlow/Keras:

import tensorflow as tf

# Load and normalize the MNIST dataset
(train_images, train_labels), _ = tf.keras.datasets.mnist.load_data()
train_images = train_images / 255.0

# Define a simple feedforward model using Keras
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Train on MNIST
model.fit(train_images, train_labels, epochs=5)

This code builds a simple feedforward network (one hidden layer) using TensorFlow’s Keras API. TensorFlow’s high-level interface makes model definition concise, and it automatically uses GPU acceleration when a GPU is available. TensorFlow supports GPU/TPU execution: users can run code on NVIDIA GPUs or Google’s TPUs and scale training using tf.distribute.Strategy. In fact, TensorFlow “excels in scaling across multiple GPUs and TPUs”. For edge deployment, TensorFlow Lite can convert models for mobile devices, and TensorFlow.js enables running models in web browsers (for example, deploying a TensorFlow model as a web app).
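
As a brief illustration of the edge path, the following minimal sketch (reusing the model variable from the example above) converts the trained Keras model to the TensorFlow Lite format:

# Convert the trained Keras model to the TensorFlow Lite format
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

# Save the flatbuffer for deployment on a mobile or embedded device
with open('model.tflite', 'wb') as f:
    f.write(tflite_model)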

PyTorch (and PyTorch Lightning)

PyTorch is an open-source framework from Meta (Facebook), known for its dynamic computation graphs and Pythonic design. Models in PyTorch are defined using torch.nn.Module subclasses or nn.Sequential. PyTorch’s eager execution makes it easy to write and debug models. Here’s an example of the same feedforward network in PyTorch:

import torch
import torch.nn as nn
import torch.optim as optim

# Define a simple model by subclassing nn.Module
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(-1, 784)          # flatten input
        x = self.relu(self.fc1(x))
        return self.fc2(x)           # raw logits; CrossEntropyLoss applies softmax internally

model = Net()
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())

# Example training loop (train_loader is assumed to be a DataLoader of (image, label) batches)
for epoch in range(5):
    for X_batch, y_batch in train_loader:
        optimizer.zero_grad()        # reset accumulated gradients
        preds = model(X_batch)
        loss = loss_fn(preds, y_batch)
        loss.backward()              # compute gradients
        optimizer.step()             # update weights

PyTorch code explicitly resets gradients with optimizer.zero_grad(), calls loss.backward() to compute gradients, and calls optimizer.step() to update weights. PyTorch supports both single-GPU and multi-GPU training (via DataParallel and DistributedDataParallel). For example, distributing across GPUs can be as simple as adding a wrapper, as sketched below, or using PyTorch Lightning’s Trainer.
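A minimal sketch of the wrapper approach, assuming CUDA GPUs are available (model is the Net instance from above):

# Replicate the model across all visible GPUs; each input batch is split among them
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
model = model.to('cuda')

For multi-node jobs or the best single-node throughput, DistributedDataParallel is the recommended option, though it requires launching one process per GPU.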

PyTorch Lightning is a higher-level framework built on PyTorch that removes boilerplate. It provides a LightningModule where you define training_step, and a Trainer class that handles loops, logging, checkpointing, and device management. For instance, a Lightning model might look like:

import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl

class LitNet(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(784, 10)
    def forward(self, x):
        return self.layer(x.view(-1, 784))
    def training_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        loss = F.cross_entropy(preds, y)
        self.log('train_loss', loss)
        return loss
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

trainer = pl.Trainer(max_epochs=5, accelerator='gpu', devices=1)
trainer.fit(LitNet(), train_loader)

In this Lightning example, we never explicitly call backward() or step() – the Trainer does it automatically. Lightning “automates 40+ tricks” including multi-GPU training and 16-bit precision. In fact, Lightning lets you “use multiple GPUs/TPUs/HPUs etc… without code changes”, greatly simplifying distributed training. Overall, PyTorch (with Lightning) offers great flexibility and rapid prototyping, making it especially popular in research. Its community is very active, and many open-source projects build on it (e.g. Hugging Face Transformers).

For deployment, PyTorch provides TorchScript (to serialize models) and supports exporting models to ONNX. ONNX export is as simple as:

dummy_input = torch.randn(1, 1, 28, 28)
torch.onnx.export(model, dummy_input, "model.onnx", export_params=True, opset_version=13)

This produces an ONNX file that can run on the ONNX Runtime across different platforms.
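
For completeness, here is a minimal sketch of running that exported file with ONNX Runtime (assuming the onnxruntime package is installed):

import numpy as np
import onnxruntime as ort

# Load the exported model and run inference on a random dummy input
session = ort.InferenceSession("model.onnx")
input_name = session.get_inputs()[0].name
dummy = np.random.randn(1, 1, 28, 28).astype(np.float32)
outputs = session.run(None, {input_name: dummy})
print(outputs[0].shape)  # (1, 10) class scores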

JAX

JAX is a newer Google library for high-performance numerical computing. It extends NumPy with automatic differentiation and GPU/TPU acceleration. Its key features include:

  • Autodiff and JIT: JAX transparently computes gradients of almost any Python function with grad(), and compiles functions with jit() for speed on accelerators.

  • NumPy API: JAX’s API is almost identical to NumPy, easing adoption of existing codebases.

  • XLA Compilation: JAX uses XLA under the hood to fuse operations and run efficiently on GPUs/TPUs.

  • Functional Style: Models are defined in a purely functional way (no side-effects). Libraries like Flax or Haiku provide neural network layers on top of JAX.

Example – a simple JAX model:

import jax.numpy as jnp
from jax import grad, jit

# Define a loss function for linear regression
def loss_fn(w, x, y):
    preds = jnp.dot(x, w)             # simple linear model
    return jnp.mean((preds - y) ** 2)

# Gradient of the loss w.r.t. the parameters, compiled with XLA
grad_loss = jit(grad(loss_fn))

# Initialize parameters and a dummy batch (stand-ins for real data)
w = jnp.zeros((784, 10))
X_batch = jnp.ones((32, 784))
y_batch = jnp.ones((32, 10))

# Compute gradients on the batch
grads = grad_loss(w, X_batch, y_batch)

This snippet defines a mean-squared-error loss and computes its gradient using JAX’s grad. Notice that we use jit to compile the gradient function. JAX handles differentiation automatically and can run this code on a GPU or TPU without modification. Its performance on accelerators is excellent thanks to XLA. Because of these features, JAX is often favored for research that involves complex math or demands maximum speed on TPU/GPU clusters. However, JAX’s ecosystem is still growing; it has fewer “batteries-included” tools than TensorFlow or PyTorch.
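
To complete the picture, a single gradient-descent update with those gradients is plain array arithmetic (the learning rate here is an arbitrary illustrative value):

# One gradient-descent step; JAX arrays are immutable, so we rebind w
learning_rate = 0.01
w = w - learning_rate * grads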

Keras

Keras is a high-level neural network API that began as an independent project and is now tightly integrated into TensorFlow. It provides a user-friendly, modular way to build networks layer by layer. Because of its simplicity, Keras is an excellent choice for beginners and rapid prototyping. (The earlier TensorFlow example used tf.keras.Sequential.) Under the hood, tf.keras benefits from TensorFlow’s optimizations and hardware support, and Keras models scale from CPU to GPU with minimal code changes. In practice, Keras is most often used within TensorFlow (i.e. tf.keras); standalone multi-backend Keras was long considered legacy, although the Keras 3 release has reintroduced backend support for TensorFlow, JAX, and PyTorch.

ONNX (Interoperability)

The Open Neural Network Exchange (ONNX) is not a training framework but a model format that enables portability between frameworks. It’s an open standard supported by many companies. ONNX allows a model trained in one framework to be run in another. For example, you can train a model in PyTorch, export to ONNX, and then load it in TensorFlow or use ONNX Runtime for deployment. ONNX also provides a “model zoo” of pre-trained networks. This interoperability is crucial in real-world pipelines: teams may develop models in the most convenient framework and then deploy them in the fastest or most compatible runtime.

ONNX Runtime itself is highly optimized for inference across platforms, including cloud servers, edge devices, and mobile. Microsoft notes that ONNX Runtime “enables interoperability between different frameworks and streamlines the path from research to production”. For deployment on resource-constrained devices, ONNX Runtime Mobile and other backends offer accelerated inference. In summary, ONNX forms a bridge: it “allows models to be trained in one framework (like PyTorch) and then deployed in another (like TensorFlow) without losing performance”.

Accelerators and Distributed Training

All modern frameworks support GPU acceleration via NVIDIA’s CUDA. TensorFlow and PyTorch automatically use GPUs for tensor operations when available. TensorFlow also targets Google’s TPUs natively, and JAX was designed with TPUs in mind. For multi-GPU or multi-node training, TensorFlow offers tf.distribute strategies, PyTorch provides torch.distributed (including DistributedDataParallel), and third-party tools like Horovod work with both. PyTorch Lightning’s Trainer can transparently handle distributed setups. For instance, using Lightning one can simply specify multiple GPUs:

trainer = pl.Trainer(accelerator='gpu', devices=4)
trainer.fit(model, train_loader)

Lightning automates multi-GPU and TPU support among many other features. In practice, distributed training is essential for very large datasets or models. TensorFlow’s distributed API and PyTorch’s DDP allow scaling to clusters of GPUs, and JAX provides pmap for the same purpose. Mixed-precision (16-bit) training is also well supported across these frameworks to further improve speed and memory efficiency.
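
On the TensorFlow side, a minimal sketch of multi-GPU data parallelism uses tf.distribute.MirroredStrategy (assuming several GPUs are visible); a model built inside the strategy scope is automatically replicated across devices:

import tensorflow as tf

# MirroredStrategy replicates the model across all visible GPUs
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(128, activation='relu', input_shape=(784,)),
        tf.keras.layers.Dense(10)
    ])
    model.compile(optimizer='adam',
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
# model.fit(...) now splits each batch across the replicas automatically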

Deployment: Edge, Mobile, and Cloud

Beyond training, deployment targets vary: cloud, on-premise, or on-device. TensorFlow offers TensorFlow Lite for mobile/embedded devices (Android, iOS) and TensorFlow.js for the web, and Keras models convert to these formats easily. PyTorch provides TorchScript and PyTorch Mobile to run on devices; models can be further optimized for inference with tools like Torch-TensorRT. ONNX Runtime has mobile and web backends as well.
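
As a minimal sketch of the TorchScript path (reusing the trained model from the PyTorch example, before any DataParallel wrapping), serializing a model takes two calls:

import torch

# Compile the trained nn.Module to TorchScript and save it as a standalone file
scripted = torch.jit.script(model)
scripted.save("model_scripted.pt")

# The saved file can later be loaded in Python or C++ without the original source code
loaded = torch.jit.load("model_scripted.pt")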

In cloud environments, all these frameworks are supported by major platforms (AWS SageMaker, Google AI Platform, Azure ML, etc.) with managed GPU/TPU instances. They integrate with containerization (Docker) for scalable serving. For example, TensorFlow Serving can serve TensorFlow/Keras models via REST/gRPC; PyTorch models can be containerized or served with third-party servers. The ecosystem is rich with supporting tools (e.g. Kubeflow, MLflow) that work with these frameworks.

Real-World Applications

Deep learning frameworks power applications across industries. In healthcare, TensorFlow and PyTorch are used for medical image analysis, genomics, and drug discovery. For instance, Genentech uses PyTorch in cancer research and personalized medicine. In autonomous vehicles, companies like Toyota leverage PyTorch for complex vision processing in self-driving systems. In technology and internet services, Google uses TensorFlow for everything from speech recognition to photo search and translation. Financial institutions apply deep learning (often using TensorFlow/PyTorch) for fraud detection, risk modeling, and algorithmic trading. While specific cases vary, the pattern is clear: the choice of framework depends on needs – PyTorch is popular in research labs and for rapid prototyping, whereas TensorFlow often underlies large-scale production systems.

Across sectors, frameworks integrate with domain-specific libraries (e.g. Hugging Face for NLP, OpenCV for vision) and toolchains. For example, ONNX enables a model trained for healthcare on PyTorch to be deployed in a C++ inference engine on hospital equipment. In finance, low-latency inference might use ONNX Runtime for speed. On mobile, TensorFlow Lite runs models on smartphones for health monitoring or personal finance apps.

Conclusion

In summary, TensorFlow, PyTorch, JAX, Keras, and related tools each offer unique strengths. TensorFlow (with Keras) provides a full-stack solution from research to production, with excellent scaling and deployment options. PyTorch is lauded for its ease of use and has a vibrant research community; PyTorch Lightning extends it for production-scale training. JAX brings high-performance XLA compilation and a familiar NumPy interface for research. Keras (as part of TensorFlow) lowers the barrier to entry for beginners. ONNX bridges these frameworks to maximize flexibility in deployment.

When choosing a framework, consider the project’s needs: the desired balance of development speed versus deployment scale, hardware targets (GPU/TPU/edge), and existing ecosystem. Often, teams use multiple tools together (e.g. PyTorch for prototyping, then convert to ONNX for serving). All these frameworks continue to evolve rapidly – integration of new accelerators, optimization techniques, and interoperability is a major focus. For practitioners from beginners to experts, mastering one or more of these frameworks is essential for modern AI development.

References:

[1] GeeksforGeeks, “Most Popular Deep Learning Software in 2024.”

[2] OpenCV Blog, “PyTorch vs TensorFlow 2025.”

[3] JAX documentation, “JAX Features.”

[4] PyTorch Lightning documentation.

[5] ONNX documentation and tutorials.
