During the Google Summer of Code (GSoC) 2024, I had the incredible opportunity to contribute to the Kubeflow open-source project by working on the integration of JAX with the Kubeflow Training Operator. The goal of this project was to provide a seamless and efficient way to run distributed computations on CPU using the JAX framework on Kubernetes. Throughout the summer, I collaborated with my mentors, Andrey Velichkevich, Yuki Iwai, Yuan Tang, and Shravan Achar, to build out this feature by extending the Training Operator.
This blog post provides an overview of the project goals, key challenges, solutions implemented, and lessons learned during my journey.
JAX, a powerful ML framework developed by Google, is highly valued for its flexibility and performance in large-scale distributed computations, especially with its native support for automatic differentiation and hardware accelerators like GPUs and TPUs. The Kubeflow Training Operator is a popular Kubernetes component that allows users to run distributed ML training jobs across various frameworks (such as TensorFlow, PyTorch, and XGBoost). However, until now, it lacked direct support for JAX.
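As a quick, illustrative aside (not part of the project itself), this is the kind of composable transformation that makes JAX attractive: grad gives automatic differentiation and jit compiles the result.

# Tiny illustration of JAX's composable transformations: grad for automatic
# differentiation, jit for compilation. Runs on CPU as-is.
import jax
import jax.numpy as jnp

def loss(w, x, y):
    return jnp.mean((x @ w - y) ** 2)  # mean squared error of a linear model

grad_fn = jax.jit(jax.grad(loss))  # compiled gradient of the loss w.r.t. w

w = jnp.zeros(3)
x = jnp.ones((4, 3))
y = jnp.ones(4)
print(grad_fn(w, x, y))  # gradient array of shape (3,)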
Create a Custom Resource for JAX (JAXJob):
We needed to introduce a new Kubernetes Custom Resource Definition (CRD) for JAX, called JAXJob, that would allow users to define distributed JAX training jobs in Kubernetes clusters. This was crucial for enabling the integration of JAX into the Training Operator.
Update the Training Operator Controller:
The Training Operator controller had to be updated to support the new JAXJob resource, handling the creation, scheduling, and management of distributed JAX training jobs on Kubernetes.
Enhance the Training Operator Python SDK:
We aimed to extend the Training Operator Python SDK to provide easy-to-use APIs for data scientists and ML practitioners to define and launch JAXJobs on Kubernetes, simplifying the process of running distributed JAX jobs.
The project required a strong understanding of JAX, Kubernetes custom resources and controllers, Go (for the Training Operator controller), and Python (for the SDK).
JAXJob Custom Resource
The first major task was to define a new Custom Resource for JAX jobs, similar to the existing TFJob, PyTorchJob, and XGBoostJob. This required defining a Kubernetes CRD that would describe the specifications for a JAX distributed training job, such as the number of workers, resource allocation, and job configuration.
The JAXJob CRD was designed to be flexible and compatible with other Kubernetes-based workflows. Here’s a basic example of a JAXJob manifest:
apiVersion: "kubeflow.org/v1"
kind: JAXJob
metadata:
  name: jaxjob-simple
spec:
  jaxReplicaSpecs:
    Worker:
      replicas: 2
      restartPolicy: OnFailure
      template:
        spec:
          containers:
            - name: jax-worker
              image: sandipanify/jaxgoogle
              command: ["python", "train.py"]
              ports:
                - containerPort: 6666
              imagePullPolicy: Always
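For context, a worker like the one above typically calls jax.distributed.initialize() before doing any collective work. Below is a rough sketch of what such a train.py could look like; the environment variable names used to discover the coordinator are illustrative assumptions, not the operator's actual contract.

# Rough sketch of a train.py a JAXJob worker could run (CPU-only).
# COORDINATOR_ADDRESS / NUM_PROCESSES / PROCESS_ID are assumed, illustrative
# environment variables for locating the coordinator, not a documented contract.
import os

import jax
import jax.numpy as jnp

jax.distributed.initialize(
    coordinator_address=os.environ.get("COORDINATOR_ADDRESS", "localhost:6666"),
    num_processes=int(os.environ.get("NUM_PROCESSES", "1")),
    process_id=int(os.environ.get("PROCESS_ID", "0")),
)

# Each process contributes a vector of ones; psum adds them across all devices
# in the cluster, so every worker ends up with the same global sum.
x = jnp.ones(jax.local_device_count())
global_sum = jax.pmap(lambda v: jax.lax.psum(v, "i"), axis_name="i")(x)
print(f"process {jax.process_index()}/{jax.process_count()} sees {global_sum}")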
The next step was to update the Training Operator controller (written in Go) to recognize and manage JAXJob resources. This required adding JAX-specific logic to handle job creation, scheduling, scaling, and monitoring.
The main controller logic involves watching for JAXJob events (create, update, delete) and ensuring that the right resources (e.g., Pods and Services) are spun up or down in response to the job’s lifecycle.
I followed the existing patterns in the Training Operator for other frameworks (such as PyTorch and XGBoost) and adapted them for JAX, ensuring consistency and reusability of the codebase.
To make this new functionality more accessible to users, I extended the Training Operator’s Python SDK. The SDK is widely used by data scientists to interact with Kubernetes resources programmatically, and adding support for JAX was a crucial step toward usability.
This SDK enhancement bridges the gap between data scientists and Kubernetes infrastructure, allowing them to focus on model development rather than cluster management.
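To give a flavor of what this looks like from the user's side, here is a minimal sketch of submitting a JAXJob with the SDK. The method and parameter names mirror how the SDK handles other frameworks (PyTorchJob, TFJob) and should be read as assumptions rather than the definitive API.

# Minimal sketch of launching a JAXJob via the Training Operator Python SDK.
# Parameter names follow the SDK's pattern for other job kinds and are
# assumptions, not a definitive reference.
from kubeflow.training import TrainingClient


def train_func():
    # Runs inside every worker Pod; the operator wires up the JAX cluster.
    import jax
    print(f"Hello from process {jax.process_index()} of {jax.process_count()}")


client = TrainingClient()

client.create_job(
    name="jaxjob-simple",
    job_kind="JAXJob",                   # selects the new CRD
    train_func=train_func,               # packaged and executed in each worker
    num_workers=2,
    base_image="sandipanify/jaxgoogle",  # image from the manifest above
)

# Stream the logs of the launched workers.
client.get_job_logs(name="jaxjob-simple", job_kind="JAXJob", follow=True)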
Testing was a critical aspect of the project. I implemented both unit and integration tests to ensure that the JAXJob CRD and the Training Operator controller functioned correctly under different scenarios, such as node failures, pod restarts, and resource contention.
By the end of the project, the following milestones were successfully achieved:
Support for the new JAXJob Custom Resource, allowing distributed JAX training jobs to be defined and run on Kubernetes.
Controller support in the Training Operator for managing the full lifecycle of JAXJob resources.
JAXJob creation in the Python SDK, making it easier for end-users to interact with JAX jobs programmatically.
Throughout this project, I gained valuable insights into distributed systems, Kubernetes resource management, and the inner workings of machine learning frameworks like JAX.
Integrating JAX with the Kubeflow Training Operator has been a challenging but rewarding experience. The project successfully enables distributed training for JAX on Kubernetes, providing an easy-to-use interface for data scientists and machine learning engineers.
I am grateful to my mentors — Andrey Velichkevich, Yuki Iwai, Yuan Tang, and Shravan Achar — for their support and guidance throughout the summer.
I look forward to seeing how this feature evolves and benefits the Kubeflow community in the future.