In the interest of managing reader expectations and avoiding disappointment, we would like to begin by stating that this post does not provide a fully satisfactory solution to the problem described in the title. We will propose and assess two possible schemes for auto-conversion of TensorFlow models to PyTorch: the first based on the Open Neural Network Exchange (ONNX) format and libraries, and the second using the Keras3 API. However, as we will see, each comes with its own set of challenges and limitations. To the best of the authors’ knowledge, at the time of this writing, there are no publicly available foolproof solutions to this problem.
Many thanks to Rom Maltser for his contributions to this post.
The Decline of TensorFlow
Over the years, the field of computer science has seen its fair share of “religious wars”: heated, sometimes hostile, debates among programmers and engineers over the “best” tools, languages, and methodologies. Until just a few years ago, the religious war between PyTorch and TensorFlow, two prominent open-source deep learning frameworks, loomed large. Proponents of TensorFlow would highlight its fast graph-execution mode, while those in the PyTorch camp would emphasize its “Pythonic” nature and ease of use.
Nowadays, however, the amount of activity in PyTorch far overshadows that of TensorFlow. This is evidenced by the number of big-tech companies that have embraced PyTorch over TensorFlow, by the number of models per framework in HuggingFace’s models repository, and by the amount of innovation and optimization in each framework. Simply put, TensorFlow is a shell of its former self. The war is over, with PyTorch the definitive winner. For a brief history of the PyTorch-TensorFlow wars and the reasons for TensorFlow’s downfall, see Pan Xinghan’s post: TensorFlow Is Dead. PyTorch Won.
Problem: What can we do with all of our legacy TensorFlow models?!!
In light of this new reality, many organizations that once used TensorFlow have moved all of their new AI/ML model development to PyTorch. But they are faced with a difficult challenge when it comes to their legacy code: What should they do with all the models that have already been built and deployed in TensorFlow?
Option 1: Do Nothing.
You might be wondering why this is even an issue: the TensorFlow models work, so let’s not touch them. While this may be a valid approach, there are several disadvantages that should be considered:
- Reduced Maintenance: As TensorFlow continues to decline, so will its maintenance. Inevitably, things will start to break. For example, there may be compatibility issues with newer Python packages or system libraries.
- Limited Ecosystem: AI/ML solutions typically involve multiple supporting software libraries and services that interface with our framework of choice, be it PyTorch or TensorFlow. Over time, we can expect to see many of these discontinue their support for TensorFlow. Case in point: HuggingFace recently announced the deprecation of its support for TensorFlow.
- Limited Community: The AI/ML industry owes its fast pace of development, in large part, to its community. The number of open-source projects, the number of online tutorials, and the amount of activity in dedicated support channels in the AI/ML space are unparalleled. As TensorFlow declines, so will its community, and you may find it increasingly difficult to get the help you need. Needless to say, the PyTorch community is flourishing.
- Opportunity Cost: The PyTorch ecosystem is thriving, with constant innovations and optimizations. Recent years have seen the development of flash-attention kernels, support for the eight-bit floating-point data type, graph compilation, and many other advancements that have demonstrated significant boosts to runtime performance and significant reductions in AI/ML costs. During the same period, the feature offering in TensorFlow has remained mostly static. Sticking with TensorFlow means forgoing many opportunities for AI/ML cost optimization.
Option 2: Manually Convert TensorFlow Models to PyTorch
The second option is to rewrite legacy TensorFlow models in PyTorch. This might be the best option in terms of its outcome, but for companies that have accumulated technical debt over many years, converting even a single model could be a daunting task. Given the effort required, you may choose to do this only for models that are still under active development (e.g., in the model training phase). Doing this for all of the models that are already deployed may prove prohibitive.
Option 3: Automate TensorFlow to PyTorch Conversion
The third option, and the approach we explore in this post, is to automate the conversion of legacy TensorFlow models to PyTorch. In this way, we hope to gain the benefits of model execution in PyTorch without the enormous effort of converting each model manually.
To facilitate our discussion, we will define a toy TensorFlow model and assess two proposals for converting it to PyTorch. As our runtime environment, we will use an Amazon EC2 g6e.xlarge with an NVIDIA L40S GPU, an AWS Deep Learning Ubuntu (22.04) AMI, and a Python environment that includes the TensorFlow (2.20), PyTorch (2.9), torchvision (0.24.0), and transformers (4.55.4) libraries. Please note that the code blocks we will share are intended for demonstrative purposes. Please do not interpret our use of any code, library, or platform as an endorsement of its use.
Model Conversion — Why is it Hard?
An AI model definition comprises two components: a model architecture and its trained weights. A model conversion solution must address both components. Conversion of the model weights is fairly straightforward; the weights are typically stored in a format that can be easily parsed into individual tensor arrays and reapplied within the framework of choice. In contrast, conversion of the model architecture presents a much greater challenge.
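“Reapplying” weights usually involves some reshuffling of axes. The NumPy sketch below is our own illustration, not part of any conversion library, and the helper names are hypothetical. It assumes the standard layouts: Keras stores Dense kernels as (in, out) and Conv2D kernels as (kh, kw, in, out), while PyTorch stores nn.Linear weights as (out, in) and nn.Conv2d weights as (out, in, kh, kw). It checks that the transposed weights reproduce the original outputs for a Dense layer and for a ViT-style, non-overlapping patch convolution.

```python
import numpy as np


def dense_kernel_to_torch(kernel):
    """Keras Dense kernel (in, out) -> torch.nn.Linear weight (out, in)."""
    return np.ascontiguousarray(kernel.T)


def conv2d_kernel_to_torch(kernel):
    """Keras Conv2D kernel (kh, kw, in, out) -> torch.nn.Conv2d weight (out, in, kh, kw)."""
    return np.ascontiguousarray(np.transpose(kernel, (3, 2, 0, 1)))


def keras_dense(x, kernel, bias):
    # Keras computes y = x @ W + b with W of shape (in, out)
    return x @ kernel + bias


def torch_linear(x, weight, bias):
    # torch.nn.Linear computes y = x @ W^T + b with W of shape (out, in)
    return x @ weight.T + bias


def patch_conv_nhwc(x, kernel):
    # Non-overlapping conv (stride == kernel size, "valid" padding), channels-last
    kh, kw, cin, _ = kernel.shape
    n, h, w, _ = x.shape
    patches = x.reshape(n, h // kh, kh, w // kw, kw, cin)
    return np.einsum("niajbc,abco->nijo", patches, kernel)


def patch_conv_nchw(x, weight):
    # The same operation in PyTorch's channels-first layout
    _, cin, kh, kw = weight.shape
    n, _, h, w = x.shape
    patches = x.reshape(n, cin, h // kh, kh, w // kw, kw)
    return np.einsum("nciajb,ocab->noij", patches, weight)


rng = np.random.default_rng(0)

# Dense layer: the transposed kernel reproduces the original output
x = rng.standard_normal((4, 8))
kernel, bias = rng.standard_normal((8, 3)), rng.standard_normal(3)
dense_match = np.allclose(keras_dense(x, kernel, bias),
                          torch_linear(x, dense_kernel_to_torch(kernel), bias))

# Patch convolution: permuted kernel plus NHWC -> NCHW input
img = rng.standard_normal((2, 32, 32, 3))
conv_k = rng.standard_normal((16, 16, 3, 5))
out_tf = patch_conv_nhwc(img, conv_k)
out_pt = patch_conv_nchw(np.transpose(img, (0, 3, 1, 2)),
                         conv2d_kernel_to_torch(conv_k))
conv_match = np.allclose(out_tf, np.transpose(out_pt, (0, 2, 3, 1)))
print(dense_match, conv_match)
```

The same idea applies to every layer type. The mapping is mechanical, but it has to be known, per layer, in advance.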
One approach could be to create a mapping between the building blocks of the model in each of the frameworks. However, several factors make this approach, for all intents and purposes, virtually intractable:
- API Overlap and Proliferation: When you consider the sheer number of (often overlapping) TensorFlow APIs for building model components, and then add the vast number of controls and arguments for each layer, you can see how creating a comprehensive, one-to-one mapping can quickly get ugly.
- Differing Implementation Approaches: At the implementation level, TensorFlow and PyTorch take fundamentally different approaches. Although these are usually hidden behind the high-level APIs, some assumptions require special user attention. For example, while TensorFlow defaults to the “channels-last” (NHWC) format, PyTorch prefers “channels-first” (NCHW). This difference in how tensors are indexed and stored complicates the conversion of model operations, as every layer must be checked and, if necessary, altered for proper dimension ordering.
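The layout difference bites hardest wherever an operation depends on memory order, for example a Flatten followed by a Dense layer. The NumPy sketch below is our own illustration (array sizes are arbitrary and the helper name is hypothetical). It shows that reusing a Keras kernel on channels-first activations silently produces different results, and that permuting the kernel rows restores the original output.

```python
import numpy as np


def nchw_flatten_permutation(h, w, c):
    """For each position k of an NCHW-flattened vector, return the index of
    the same element in the NHWC-flattened vector."""
    return np.arange(h * w * c).reshape(h, w, c).transpose(2, 0, 1).reshape(-1)


rng = np.random.default_rng(0)
n, h, w, c, units = 2, 3, 4, 5, 6

x_nhwc = rng.standard_normal((n, h, w, c))        # TensorFlow activation layout
x_nchw = np.transpose(x_nhwc, (0, 3, 1, 2))       # PyTorch activation layout
kernel = rng.standard_normal((h * w * c, units))  # Dense kernel trained after Flatten

# Reference: Keras flattens channels-last activations
y_ref = x_nhwc.reshape(n, -1) @ kernel

# Naive port: same kernel, channels-first flatten -> features in the wrong order
y_naive = x_nchw.reshape(n, -1) @ kernel

# Fixed port: reorder the kernel rows to follow the NCHW flatten order
perm = nchw_flatten_permutation(h, w, c)
y_fixed = x_nchw.reshape(n, -1) @ kernel[perm]

print("naive matches:", np.allclose(y_ref, y_naive))
print("fixed matches:", np.allclose(y_ref, y_fixed))
```

Nothing fails loudly in the naive case: shapes line up and the model runs, which is exactly why layout bugs of this kind are easy to miss during conversion.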
Rather than attempt conversion at the API level, an alternative approach could be to capture and convert an internal TensorFlow graph representation. However, as anyone who has ever looked under the hood of TensorFlow will tell you, this too can get pretty nasty very quickly. TensorFlow’s internal graph representation is highly complex, often including a large number of low-level operations, control flow, and auxiliary nodes that have no direct equivalent in PyTorch (especially if you are dealing with older versions of TensorFlow). Merely comprehending it seems beyond normal human ability, let alone converting it to PyTorch.
Note that the same challenges would make it difficult for a generative AI model to perform the conversion in a fully reliable manner.
Proposed Conversion Schemes
In light of these difficulties, we abandon our attempt at implementing our own model converter and instead look at what tools the AI/ML community has to offer. More specifically, we consider two different strategies for overcoming the challenges we described:
- Conversion Via a Unified Graph Representation: This solution assumes a common standard for representing an AI/ML model definition, along with utilities for converting models to and from this standard. The solution we will explore uses the popular ONNX format.
- Conversion Based on a Standardized High-level API: In this solution, we simplify the conversion task by limiting our model to a defined set of high-level abstract APIs with supported implementations in each of the AI/ML frameworks of interest. For this approach, we will use the Keras3 library.
In the following sections, we will assess these strategies on a toy TensorFlow model.
A Toy TensorFlow Model
In the code block below, we initialize a TensorFlow Vision Transformer (ViT) model, TFViTForImageClassification, from HuggingFace’s popular transformers library (version 4.55.4). Note that, in line with HuggingFace’s decision to deprecate support for TensorFlow, this class was removed from recent releases of the library. The HuggingFace TensorFlow model depends on Keras 2, which we dutifully install via the tf-keras (2.20.1) package. We set the ViTConfig.hidden_act field to “gelu_new” for ONNX compatibility:
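```python
import tensorflow as tf

# let TensorFlow allocate GPU memory on demand rather than all at once
gpu = tf.config.list_physical_devices('GPU')[0]
tf.config.experimental.set_memory_growth(gpu, True)

from transformers import ViTConfig, TFViTForImageClassification

# "gelu_new" (tanh-approximated GELU) is used for ONNX compatibility
vit_config = ViTConfig(hidden_act="gelu_new", return_dict=False)
tf_model = TFViTForImageClassification(vit_config)
```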
Model Conversion Using ONNX
The first method we assess relies on Open Neural Network Exchange (ONNX), a community project that aims to define an open format for building AI/ML models, in order to increase interoperability between AI/ML frameworks and reduce dependence on any single one. The ONNX offering includes utilities for converting models from common frameworks, including TensorFlow, to the ONNX format. There are also several public libraries for converting ONNX models to PyTorch. In this post, we use the onnx2torch utility. Thus, model conversion from TensorFlow to PyTorch can be achieved by applying TensorFlow-to-ONNX conversion followed by ONNX-to-PyTorch conversion.
To assess this solution, we install the onnx (1.19.1), tf2onnx (1.16.1), and onnx2torch (1.5.15) libraries. We apply the --no-deps flag to prevent an undesired downgrade of the protobuf library:
```sh
pip install --no-deps onnx tf2onnx onnx2torch
```
The conversion scheme appears in the code block below:
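```python
import tensorflow as tf
import torch
import tf2onnx
import onnx2torch

BATCH_SIZE = 32
DEVICE = "cuda"

# HF's TF ViT expects channels-first pixel values: (batch, channels, height, width)
spec = (tf.TensorSpec((BATCH_SIZE, 3, 224, 224), tf.float32, name="input"),)

# step 1: TensorFlow (Keras 2) -> ONNX
onnx_model, _ = tf2onnx.convert.from_keras(tf_model, input_signature=spec)

# step 2: ONNX -> PyTorch
converted_model = onnx2torch.convert(onnx_model)
```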
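A quick way to see what the intermediate ONNX graph contains is to tally its operator types. The helper below is our own sketch; it relies only on the ModelProto fields graph.node and op_type from the onnx package. To keep it runnable without TensorFlow, we exercise it on a small stand-in object; with the real pipeline, one would call it on the onnx_model returned by tf2onnx.convert.from_keras.

```python
from collections import Counter
from types import SimpleNamespace


def onnx_op_histogram(model_proto, top=10):
    """Count node operator types in an ONNX ModelProto (or any object exposing
    .graph.node[i].op_type), most common first."""
    counts = Counter(node.op_type for node in model_proto.graph.node)
    return counts.most_common(top)


# Stand-in for an ONNX model so the sketch runs on its own;
# with the converted ViT, call onnx_op_histogram(onnx_model) instead.
fake_nodes = [SimpleNamespace(op_type=t) for t in
              ["MatMul", "Add", "Reshape", "Transpose", "Reshape", "MatMul", "Reshape"]]
fake_model = SimpleNamespace(graph=SimpleNamespace(node=fake_nodes))

for op_type, count in onnx_op_histogram(fake_model):
    print(f"{op_type:>12}: {count}")
```

A histogram dominated by low-level ops such as Reshape, Transpose, and MatMul, rather than by recognizable “attention” or “MLP” blocks, is a first hint of how the converted PyTorch model will be structured.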
To ensure that the resulting model is indeed a PyTorch module, we run the following assertion:
```python
assert isinstance(converted_model, torch.nn.Module)
```
Let us now assess the quality and makeup of the resulting PyTorch model.
Numerical Precision
To verify the validity of the converted model, we execute both the TensorFlow model and the converted model on the same input and compare the results:
```python
import numpy as np

batch_input = np.random.randn(BATCH_SIZE, 3, 224, 224).astype(np.float32)

# execute tf model
tf_input = tf.convert_to_tensor(batch_input)
tf_output = tf_model(tf_input, training=False)
tf_output = tf_output[0].numpy()

# execute converted model
converted_model = converted_model.to(DEVICE)
converted_model = converted_model.eval()
torch_input = torch.from_numpy(batch_input).to(DEVICE)
torch_output = converted_model(torch_input)
torch_output = torch_output.detach().cpu().numpy()

# compare results
print("Max diff:", np.max(np.abs(tf_output - torch_output)))

# sample output:
# Max diff: 9.3877316e-07
```
The outputs are indeed close enough to validate the converted model.
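A single max-abs-diff number can hide problems, for example when logits are large or when only a few samples disagree. A slightly fuller check, sketched below with NumPy, also reports relative error and whether the predicted classes agree. The function name and the default tolerances are our own assumptions and would need tuning per model and precision.

```python
import numpy as np


def compare_outputs(ref, test, rtol=1e-4, atol=1e-5):
    """Summarize how closely two (batch, num_classes) logit arrays agree."""
    ref = np.asarray(ref, dtype=np.float64)
    test = np.asarray(test, dtype=np.float64)
    abs_diff = np.abs(ref - test)
    # guard against division by zero for logits that are exactly 0
    rel_diff = abs_diff / np.maximum(np.abs(ref), 1e-12)
    return {
        "max_abs_diff": float(abs_diff.max()),
        "max_rel_diff": float(rel_diff.max()),
        "allclose": bool(np.allclose(test, ref, rtol=rtol, atol=atol)),
        "argmax_agreement": float(np.mean(ref.argmax(-1) == test.argmax(-1))),
    }


# With the arrays from the block above: compare_outputs(tf_output, torch_output)
rng = np.random.default_rng(0)
logits = rng.standard_normal((32, 2)).astype(np.float32)
print(compare_outputs(logits, logits + 1e-7))
```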
Model Structure
To get a feel for the structure of the converted model, we count its trainable parameters and compare the total to that of the original model:
```python
num_tf_params = sum(int(np.prod(v.shape)) for v in tf_model.trainable_weights)
num_pyt_params = sum(p.numel()
                     for p in converted_model.parameters()
                     if p.requires_grad)

print(f"TensorFlow trainable parameters: {num_tf_params:,}")
print(f"PyTorch trainable parameters: {num_pyt_params:,}")
```
The difference in the number of trainable parameters is profound: just 589,824 in the converted model compared with over 85 million in the original model. Traversing the layers of the converted model leads to the same conclusion: the ONNX-based conversion has completely altered the model structure, rendering it essentially unrecognizable. This finding has several ramifications, including:
- Training/fine-tuning the converted model: Although we have shown that the converted model can be used for inference, the change in structure (in particular, the fact that many of the model parameters have been baked in) means that we cannot use the converted model for training or fine-tuning.
- Applying targeted PyTorch optimizations to the model: The converted model consists of a very large number of layers, each representing a relatively low-level operation. This greatly limits our ability to replace inefficient operations with optimized PyTorch equivalents, such as torch.nn.functional.scaled_dot_product_attention (SDPA).
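For reference, the expected size of the original model can be derived by hand. The function below is our own sketch. It assumes the ViT-Base defaults used by our toy model (768 hidden units, 12 layers, a 3,072-unit MLP, 16×16 patches on 224×224 RGB images, and 2 labels, matching the constants in the Keras3 rewrite later in this post), biased Dense layers, LayerNorms with a scale and a shift, and no pooler. Incidentally, the patch-projection kernel alone holds 3 × 16 × 16 × 768 = 589,824 weights, the same number reported for the converted model; we have not verified whether that is the tensor onnx2torch left trainable.

```python
def vit_classifier_params(hidden=768, layers=12, mlp=3072, patch=16,
                          image=224, channels=3, labels=2):
    """Trainable parameter count of a ViT image classifier (no pooler)."""
    def dense(n_in, n_out):
        return n_in * n_out + n_out          # weight + bias

    layer_norm = 2 * hidden                  # gamma + beta
    num_patches = (image // patch) ** 2

    embeddings = (
        dense(channels * patch * patch, hidden)  # patch projection (conv)
        + hidden                                 # [CLS] token
        + (num_patches + 1) * hidden             # position embeddings
    )
    per_layer = (
        3 * dense(hidden, hidden)  # query, key, value
        + dense(hidden, hidden)    # attention output projection
        + dense(hidden, mlp)       # MLP up-projection
        + dense(mlp, hidden)       # MLP down-projection
        + 2 * layer_norm           # layernorm_before / layernorm_after
    )
    head = layer_norm + dense(hidden, labels)  # final LayerNorm + classifier
    return embeddings + layers * per_layer + head


total = vit_classifier_params()
print(f"{total:,}")
```

With the defaults, this prints 85,800,194, consistent with the “over 85 million” figure for the original model.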
Model Optimization
We have already seen that our ability to access and modify model operations is limited, but there are several optimizations we can apply that do not require such access. In the code block below, we apply PyTorch compilation and automatic mixed precision (AMP) and compare the resulting throughput to that of the TensorFlow model. For further context, we also test the runtime of the PyTorch version of the ViTForImageClassification model:
```python
# Set tf mixed precision policy to bfloat16
tf.keras.mixed_precision.set_global_policy('mixed_bfloat16')

# Set torch matmul precision to high
torch.set_float32_matmul_precision('high')


@tf.function
def tf_infer_fn(batch):
    return tf_model(batch, training=False)


def get_torch_infer_fn(model):
    def infer_fn(batch):
        with torch.inference_mode(), torch.amp.autocast(
            DEVICE,
            dtype=torch.bfloat16,
            enabled=DEVICE == 'cuda'
        ):
            output = model(batch)
        return output
    return infer_fn


def benchmark(infer_fn, batch):
    # warm-up
    for _ in range(20):
        _ = infer_fn(batch)

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()

    iters = 100
    for _ in range(iters):
        _ = infer_fn(batch)

    end.record()
    torch.cuda.synchronize()
    # elapsed_time() returns milliseconds; report the mean per step
    return start.elapsed_time(end) / iters


# assess throughput of TF model
avg_time = benchmark(tf_infer_fn, tf_input)
print(f"\nTensorFlow average step time: {avg_time:.4f} ms")

# assess throughput of converted model
torch_infer_fn = get_torch_infer_fn(converted_model)
avg_time = benchmark(torch_infer_fn, torch_input)
print(f"\nConverted model average step time: {avg_time:.4f} ms")

# assess throughput of compiled model
torch_infer_fn = get_torch_infer_fn(torch.compile(converted_model))
avg_time = benchmark(torch_infer_fn, torch_input)
print(f"\nCompiled model average step time: {avg_time:.4f} ms")

# assess throughput of torch ViT
from transformers import ViTForImageClassification
torch_model = ViTForImageClassification(vit_config).to(DEVICE)
torch_infer_fn = get_torch_infer_fn(torch_model)
avg_time = benchmark(torch_infer_fn, torch_input)
print(f"\nPyTorch ViT model average step time: {avg_time:.4f} ms")

# assess throughput of compiled torch ViT
torch_infer_fn = get_torch_infer_fn(torch.compile(torch_model))
avg_time = benchmark(torch_infer_fn, torch_input)
print(f"\nCompiled ViT model average step time: {avg_time:.4f} ms")
```
Note that PyTorch compilation initially fails on the converted model due to the use of the torch.Size operator in the OnnxReshape layer. While this is easily fixable (e.g., tuple([int(i) for i in shape])), it points to a deeper obstacle to optimizing the model: the reshape layer, which appears dozens of times in the model, treats shapes as PyTorch tensors residing on the GPU. This means that each call requires detaching the shape tensor from the graph and copying it to the CPU. The conclusion is that although the converted model is functionally accurate, its resulting definition is not optimized for runtime performance. This can be seen from the step time results of the different model configurations:
The converted model is slower than the original TensorFlow model and significantly slower than the PyTorch version of the ViT model.
Limitations
Although (in the case of our toy model) the ONNX-based conversion scheme works, it has several significant limitations:
- During the conversion, many parameters were baked into the model, limiting its use to inference workloads only.
- The ONNX conversion breaks the computation graph into low-level operators in a way that makes it difficult to apply, or benefit from, some PyTorch optimizations.
- The reliance on ONNX means that our conversion scheme will only work on ONNX-friendly models. It will not work on models that cannot be mapped to the standard ONNX operator set (e.g., models with dynamic control flow).
- The conversion scheme relies on the health and maintenance of a third-party library that is not part of the official ONNX offering.
Although the scheme works, at least for inference workloads, you may find its limitations too restrictive for use with your own TensorFlow models. One possible option is to abandon the ONNX-to-PyTorch conversion and perform inference using the ONNX Runtime library.
Model Conversion Via Keras3
Keras3 is a high-level deep learning API focused on maximizing the readability, maintainability, and ease of use of AI/ML applications. In a previous post, we evaluated Keras3 and highlighted its support for multiple backends. In this post, we revisit its multi-framework support and assess whether it can be used for model conversion. The scheme we propose is to 1) migrate the existing TensorFlow model to Keras3 and then 2) run the model with the Keras3 PyTorch backend.
Upgrading TensorFlow to Keras3
Unlike the ONNX-based conversion scheme, our current solution may require some code changes to the TensorFlow model in order to migrate it to Keras3. While the documentation makes it sound easy, in practice the difficulty of the migration will depend greatly on the details of the model implementation. In the case of our toy model, HuggingFace explicitly enforces the use of the legacy tf-keras, preventing the use of Keras3. To implement our scheme, we need to 1) redefine the model without this restriction, and 2) replace native TensorFlow operators with Keras3 equivalents. The code block below contains a stripped-down version of the model, together with the required adjustments. To get a full grasp of the changes that were required, perform a side-by-side comparison with the original model definition.
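When migrating a larger code base, it can help to first inventory the native TensorFlow calls that need replacing. The helper below is a hypothetical sketch of our own, not a Keras tool. Its mapping table is deliberately incomplete and covers only the ops used in our toy model, and it flags calls for manual review rather than rewriting them, since argument names and semantics do not always line up (e.g., tf.concat and keras.ops.concatenate).

```python
import re

# Hypothetical, deliberately incomplete mapping from TensorFlow ops to keras.ops
TF_TO_KERAS_OPS = {
    "tf.cast": "keras.ops.cast",
    "tf.concat": "keras.ops.concatenate",
    "tf.matmul": "keras.ops.matmul",
    "tf.reshape": "keras.ops.reshape",
    "tf.transpose": "keras.ops.transpose",
    "tf.tile": "keras.ops.tile",
    "tf.nn.softmax": "keras.ops.softmax",
    "tf.math.divide": "keras.ops.divide",
}

# A dotted name starting with "tf." that is immediately called, e.g. tf.nn.softmax(
TF_CALL = re.compile(r"\btf(?:\.\w+)+(?=\s*\()")


def audit_tf_calls(source):
    """Return (line number, TF call, suggested keras.ops replacement or None)."""
    findings = []
    for lineno, line in enumerate(source.splitlines(), start=1):
        for match in TF_CALL.finditer(line):
            name = match.group(0)
            findings.append((lineno, name, TF_TO_KERAS_OPS.get(name)))
    return findings


snippet = """x = tf.cast(a, tf.float32)
y = tf.nn.softmax(x, axis=-1)
z = tf.py_function(my_fn, [y], tf.float32)
"""
for lineno, name, suggestion in audit_tf_calls(snippet):
    print(lineno, name, "->", suggestion or "needs manual attention")
```

Calls without a suggestion (such as tf.py_function here) are the ones most likely to require real redesign rather than a one-line substitution.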
```python
# keras3_vit.py
import math
import keras

HIDDEN_SIZE = 768
IMG_SIZE = 224
PATCH_SIZE = 16
ATTN_HEADS = 12
NUM_LAYERS = 12
INTER_SZ = 4*HIDDEN_SIZE
N_LABELS = 2


class TFViTEmbeddings(keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.patch_embeddings = TFViTPatchEmbeddings()
        num_patches = self.patch_embeddings.num_patches
        self.cls_token = self.add_weight((1, 1, HIDDEN_SIZE))
        self.position_embeddings = self.add_weight((1, num_patches+1,
                                                    HIDDEN_SIZE))

    def call(self, pixel_values, training=False):
        bs, num_channels, height, width = pixel_values.shape
        embeddings = self.patch_embeddings(pixel_values, training=training)
        cls_tokens = keras.ops.repeat(self.cls_token, repeats=bs, axis=0)
        embeddings = keras.ops.concatenate((cls_tokens, embeddings), axis=1)
        embeddings = embeddings + self.position_embeddings
        return embeddings


class TFViTPatchEmbeddings(keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        patch_size = (PATCH_SIZE, PATCH_SIZE)
        image_size = (IMG_SIZE, IMG_SIZE)
        num_patches = ((image_size[1] // patch_size[1]) *
                       (image_size[0] // patch_size[0]))
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.projection = keras.layers.Conv2D(
            filters=HIDDEN_SIZE,
            kernel_size=patch_size,
            strides=patch_size,
            padding="valid",
            data_format="channels_last"
        )

    def call(self, pixel_values, training=False):
        bs, num_channels, height, width = pixel_values.shape
        # NCHW -> NHWC to match the channels-last Conv2D
        pixel_values = keras.ops.transpose(pixel_values, (0, 2, 3, 1))
        projection = self.projection(pixel_values)
        num_patches = ((width // self.patch_size[1]) *
                       (height // self.patch_size[0]))
        embeddings = keras.ops.reshape(projection, (bs, num_patches, -1))
        return embeddings


class TFViTSelfAttention(keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.num_attention_heads = ATTN_HEADS
        self.attention_head_size = int(HIDDEN_SIZE / ATTN_HEADS)
        self.all_head_size = ATTN_HEADS * self.attention_head_size
        self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
        self.query = keras.layers.Dense(self.all_head_size, name="query")
        self.key = keras.layers.Dense(self.all_head_size, name="key")
        self.value = keras.layers.Dense(self.all_head_size, name="value")

    def transpose_for_scores(self, tensor, batch_size: int):
        # (bs, seq, hidden) -> (bs, heads, seq, head_size)
        tensor = keras.ops.reshape(tensor, (batch_size, -1, ATTN_HEADS,
                                            self.attention_head_size))
        return keras.ops.transpose(tensor, [0, 2, 1, 3])

    def call(self, hidden_states, training=False):
        bs = hidden_states.shape[0]
        mixed_query_layer = self.query(inputs=hidden_states)
        mixed_key_layer = self.key(inputs=hidden_states)
        mixed_value_layer = self.value(inputs=hidden_states)
        query_layer = self.transpose_for_scores(mixed_query_layer, bs)
        key_layer = self.transpose_for_scores(mixed_key_layer, bs)
        value_layer = self.transpose_for_scores(mixed_value_layer, bs)
        key_layer_T = keras.ops.transpose(key_layer, [0, 1, 3, 2])
        attention_scores = keras.ops.matmul(query_layer, key_layer_T)
        dk = keras.ops.cast(self.sqrt_att_head_size,
                            dtype=attention_scores.dtype)
        attention_scores = keras.ops.divide(attention_scores, dk)
        attention_probs = keras.ops.softmax(attention_scores+1e-9, axis=-1)
        attention_output = keras.ops.matmul(attention_probs, value_layer)
        attention_output = keras.ops.transpose(attention_output, [0, 2, 1, 3])
        attention_output = keras.ops.reshape(attention_output,
                                             (bs, -1, self.all_head_size))
        return (attention_output,)


class TFViTSelfOutput(keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dense = keras.layers.Dense(HIDDEN_SIZE)

    def call(self, hidden_states, input_tensor, training=False):
        return self.dense(inputs=hidden_states)


class TFViTAttention(keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.self_attention = TFViTSelfAttention()
        self.dense_output = TFViTSelfOutput()

    def call(self, input_tensor, training=False):
        self_outputs = self.self_attention(
            hidden_states=input_tensor, training=training
        )
        attention_output = self.dense_output(
            hidden_states=self_outputs[0],
            input_tensor=input_tensor,
            training=training
        )
        return (attention_output,)


class TFViTIntermediate(keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dense = keras.layers.Dense(INTER_SZ)
        # tanh-approximated GELU, to match hidden_act="gelu_new" in the TF model
        self.intermediate_act_fn = lambda x: keras.activations.gelu(
            x, approximate=True)

    def call(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states


class TFViTOutput(keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.dense = keras.layers.Dense(HIDDEN_SIZE)

    def call(self, hidden_states, input_tensor, training: bool = False):
        hidden_states = self.dense(inputs=hidden_states)
        hidden_states = hidden_states + input_tensor
        return hidden_states


class TFViTLayer(keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.attention = TFViTAttention()
        self.intermediate = TFViTIntermediate()
        self.vit_output = TFViTOutput()
        self.layernorm_before = keras.layers.LayerNormalization(
            epsilon=1e-12
        )
        self.layernorm_after = keras.layers.LayerNormalization(
            epsilon=1e-12
        )

    def call(self, hidden_states, training=False):
        attention_outputs = self.attention(
            input_tensor=self.layernorm_before(inputs=hidden_states),
            training=training,
        )
        attention_output = attention_outputs[0]
        hidden_states = attention_output + hidden_states
        layer_output = self.layernorm_after(hidden_states)
        intermediate_output = self.intermediate(layer_output)
        layer_output = self.vit_output(
            hidden_states=intermediate_output,
            input_tensor=hidden_states,
            training=training
        )
        outputs = (layer_output,)
        return outputs


class TFViTEncoder(keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.layer = [TFViTLayer(name=f"layer_{i}")
                      for i in range(NUM_LAYERS)]

    def call(self, hidden_states, training=False):
        for i, layer_module in enumerate(self.layer):
            layer_outputs = layer_module(
                hidden_states=hidden_states,
                training=training,
            )
            hidden_states = layer_outputs[0]
        return tuple([hidden_states])


class TFViTMainLayer(keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.embeddings = TFViTEmbeddings()
        self.encoder = TFViTEncoder()
        self.layernorm = keras.layers.LayerNormalization(epsilon=1e-12)

    def call(self, pixel_values, training=False):
        embedding_output = self.embeddings(
            pixel_values=pixel_values,
            training=training,
        )
        encoder_outputs = self.encoder(
            hidden_states=embedding_output,
            training=training,
        )
        sequence_output = encoder_outputs[0]
        sequence_output = self.layernorm(inputs=sequence_output)
        return (sequence_output,)


class TFViTForImageClassification(keras.Model):
    def __init__(self, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)
        self.vit = TFViTMainLayer()
        self.classifier = keras.layers.Dense(N_LABELS)

    def call(self, pixel_values, training=False):
        outputs = self.vit(pixel_values, training=training)
        sequence_output = outputs[0]
        logits = self.classifier(inputs=sequence_output[:, 0, :])
        return (logits,)
```
TensorFlow to PyTorch Conversion
The conversion sequence appears in the code block below. As before, we validate the output of the resulting model as well as the number of trainable parameters.
```python
# save weights of TensorFlow model
tf_model.save_weights("model_weights.h5")

import keras
keras.config.set_backend("torch")

# assumes the Keras3 model definition above is saved as keras3_vit.py
from keras3_vit import TFViTForImageClassification as Keras3ViT

keras3_model = Keras3ViT()

# call model to initialize all layers
keras3_model(torch_input, training=False)

# load the weights from the TensorFlow model
keras3_model.load_weights("model_weights.h5")

# validate converted model
assert isinstance(keras3_model, torch.nn.Module)
keras3_model = keras3_model.to(DEVICE)
keras3_model = keras3_model.eval()
torch_output = keras3_model(torch_input, training=False)
torch_output = torch_output[0].detach().cpu().numpy()
print("Max diff:", np.max(np.abs(tf_output - torch_output)))

num_pyt_params = sum(p.numel()
                     for p in keras3_model.parameters()
                     if p.requires_grad)
print(f"Keras3 trainable parameters: {num_pyt_params:,}")
```
Training/Fine-tuning the Model
Unlike the ONNX-converted model, the Keras3 model maintains the same structure and trainable parameters. This allows us to resume training and/or fine-tune the converted model, either within the Keras3 training framework or using a standard PyTorch training loop.
Optimizing Model Layers
Unlike the ONNX-converted model, the coherent Keras3 model definition makes it easy to modify and optimize the layer implementations. In the code block below, we replace the existing attention mechanism with PyTorch’s highly efficient SDPA operator.
```python
from torch.nn.functional import scaled_dot_product_attention as sdpa


class TFViTSelfAttention(keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.num_attention_heads = ATTN_HEADS
        self.attention_head_size = int(HIDDEN_SIZE / ATTN_HEADS)
        self.all_head_size = ATTN_HEADS * self.attention_head_size
        self.sqrt_att_head_size = math.sqrt(self.attention_head_size)
        self.query = keras.layers.Dense(self.all_head_size, name="query")
        self.key = keras.layers.Dense(self.all_head_size, name="key")
        self.value = keras.layers.Dense(self.all_head_size, name="value")

    def transpose_for_scores(self, tensor, batch_size: int):
        tensor = keras.ops.reshape(tensor, (batch_size, -1, ATTN_HEADS,
                                            self.attention_head_size))
        return keras.ops.transpose(tensor, [0, 2, 1, 3])

    def call(self, hidden_states, training=False):
        bs = hidden_states.shape[0]
        mixed_query_layer = self.query(inputs=hidden_states)
        mixed_key_layer = self.key(inputs=hidden_states)
        mixed_value_layer = self.value(inputs=hidden_states)
        query_layer = self.transpose_for_scores(mixed_query_layer, bs)
        key_layer = self.transpose_for_scores(mixed_key_layer, bs)
        value_layer = self.transpose_for_scores(mixed_value_layer, bs)
        sdpa_output = sdpa(query_layer, key_layer, value_layer)
        attention_output = keras.ops.transpose(sdpa_output, [0, 2, 1, 3])
        attention_output = keras.ops.reshape(attention_output,
                                             (bs, -1, self.all_head_size))
        return (attention_output,)
```
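Why is this swap numerically safe? With its default arguments, scaled_dot_product_attention computes softmax(QKᵀ/√d)·V, the same 1/√d scaling as the original layer, and the 1e-9 added before the softmax in the original code has no effect because softmax is invariant to adding a constant. The NumPy sketch below is our own check, not part of the model code; it compares the two formulations on random tensors shaped (batch, heads, tokens, head_dim) like those in our ViT.

```python
import math

import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def original_attention(q, k, v):
    # Mirrors the first Keras3 layer: matmul, divide by sqrt(d), add 1e-9, softmax, matmul
    scores = q @ np.swapaxes(k, -1, -2) / math.sqrt(q.shape[-1])
    return softmax(scores + 1e-9) @ v


def reference_sdpa(q, k, v):
    # Reference formula for SDPA with default arguments: softmax(q k^T * scale) v
    scale = 1.0 / math.sqrt(q.shape[-1])
    return softmax(q @ np.swapaxes(k, -1, -2) * scale) @ v


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((2, 12, 197, 64)) for _ in range(3))
print(np.max(np.abs(original_attention(q, k, v) - reference_sdpa(q, k, v))))
```

The speedup comes from SDPA dispatching to fused kernels (such as flash attention) rather than from any change in the math.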
We use the same benchmarking function from above to assess the impact of this optimization on the model’s runtime performance:
```python
torch_infer_fn = get_torch_infer_fn(keras3_model)
avg_time = benchmark(torch_infer_fn, torch_input)
print(f"Keras3 converted model average step time: {avg_time:.4f} ms")
```
The results are captured in the table below:

Using the Keras3-based model conversion scheme and applying the SDPA optimization, we are able to increase model inference throughput by 22% compared to the original TensorFlow model.
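For clarity, here is the arithmetic that links step times to throughput gains. With a fixed batch size, throughput is inversely proportional to step time, so a throughput gain can be computed from the two step times. This is one common convention, and we assume it is the one behind such figures. The step times in the example are placeholders, not measured results.

```python
def throughput(batch_size, step_time_ms):
    """Samples processed per second for a given per-step latency in milliseconds."""
    return batch_size * 1000.0 / step_time_ms


def throughput_gain(baseline_ms, new_ms):
    """Relative throughput improvement of `new` over `baseline` (0.22 == 22%)."""
    return baseline_ms / new_ms - 1.0


# Placeholder step times for illustration only (not the measured results)
baseline_ms, optimized_ms = 100.0, 82.0
print(f"baseline:  {throughput(32, baseline_ms):.1f} samples/s")
print(f"optimized: {throughput(32, optimized_ms):.1f} samples/s")
print(f"gain: {throughput_gain(baseline_ms, optimized_ms):.1%}")
```

Note the asymmetry: a 22% throughput gain corresponds to roughly an 18% reduction in step time (1 / 1.22 ≈ 0.82).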
Model Compilation
Another optimization we would like to apply is PyTorch compilation. Unfortunately (as of the time of this writing), support for PyTorch compilation in Keras3 is limited. In the case of our toy model, both our attempt to apply torch.compile directly to the model and our attempt to set the relevant field of the Keras3 Model.compile function failed. In both cases, the failure resulted from multiple recompilations triggered by the Keras3 internal machinery. While Keras3 grants access to the PyTorch ecosystem, its high-level abstraction may impose some limitations.
Limitations
Once again, we have a conversion scheme that works but has several limitations:
- The TensorFlow models must be Keras3-compatible. The amount of work this requires will depend on the details of your model implementation. It may require some Keras layer customization.
- While the resulting model is a torch.nn.Module, it is not a “pure” PyTorch model, in the sense that it is composed of Keras3 layers and includes a lot of additional Keras3 code. This may require some adaptations to our PyTorch tooling and may impose some restrictions, as we saw when we tried to apply PyTorch compilation.
- The solution relies on the health and maintenance of Keras3 and its support for the TensorFlow and PyTorch backends.
Summary
In this post, we have proposed and assessed two methods for auto-conversion of legacy TensorFlow models to PyTorch. We summarize our findings in the following table.

Ultimately, the best approach, whether it is one of the methods discussed here, manual conversion, a solution based on generative AI, or the decision not to convert at all, will depend greatly on the details of the model and the situation.
