
Impressive, but only forward pass.


I think the completeness and self-contained-ness more than offset the limited scope. One of the problems in the ML field is rapidly multiplying logistical complexity, and I appreciate an example that is (somewhat) functional but simple enough to fit on a postcard and uses very basic components.


It's an excellent learning tool :) Doing the backward pass in the same style would be a great tool for teaching.


Just replace the numpy code with jax.numpy and you should have a fully differentiable model ready for training!


For someone not familiar with JAX: if I do the suggested replacement, what would be the little extra code to make it do the backward pass? Or is it all automatic, and we literally would not need extra lines of code?


Backprop is just an implementation detail when doing automatic differentiation, basically setting up how you would apply the chain rule to your problem.

JAX is able to differentiate arbitrary python code (so long as it uses JAX for the numeric stuff) automatically so the backprop is abstracted away.
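
As a rough illustration of that point, here is a minimal sketch. The function `f` is a made-up toy, not anything from the linked code; it uses a plain Python loop and does its numeric work through `jax.numpy`, and `jax.grad` returns its derivative without any hand-written backward pass. The finite-difference value is only there as a sanity check.

```python
import jax
import jax.numpy as jnp

def f(x):
    # Hypothetical toy function: an ordinary Python loop,
    # with the numeric operations going through jax.numpy.
    y = x
    for _ in range(3):
        y = jnp.sin(y) * x
    return y

# jax.grad builds a new function that computes df/dx.
df = jax.grad(f)

x = 0.7
eps = 1e-3
# Central finite difference, used only to sanity-check the autodiff result.
numeric = (f(x + eps) - f(x - eps)) / (2 * eps)
auto = df(x)
```

The two values should agree to within floating-point and finite-difference error.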

If you have the forward model written, to train it all you have to do is wrap it in whatever loss function you want, then use JAX's `grad` with respect to the model parameters, and you can use that gradient to find the optimum with your favorite gradient-based optimization algorithm.
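
To make that recipe concrete, here is a hedged sketch. The `forward` function below is a stand-in single linear layer, not the model from the linked code, and the data, parameter names, learning rate, and step count are all assumptions chosen for illustration. The pattern is the one described: wrap the forward pass in a loss, call `jax.grad` on it, and feed the gradients to plain gradient descent.

```python
import jax
import jax.numpy as jnp

def forward(params, x):
    # Stand-in forward pass (assumption): a single linear layer.
    return x @ params["w"] + params["b"]

def loss_fn(params, x, y):
    # Wrap the forward pass in a loss; mean squared error here.
    pred = forward(params, x)
    return jnp.mean((pred - y) ** 2)

# jax.grad differentiates with respect to the first argument (params)
# and returns gradients with the same dict structure as params.
grad_fn = jax.grad(loss_fn)

# Synthetic data for illustration only.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (32, 3))
y = x @ jnp.array([1.0, -2.0, 0.5]) + 0.3

params = {"w": jnp.zeros(3), "b": jnp.array(0.0)}
initial_loss = loss_fn(params, x, y)

lr = 0.1  # learning rate picked by hand for this toy problem
for _ in range(200):
    grads = grad_fn(params, x, y)
    # Plain gradient descent step applied to every leaf of the params dict.
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

final_loss = loss_fn(params, x, y)
```

No backward pass is written anywhere; the only extra code beyond the forward model is the loss, the `jax.grad` call, and the update loop. In practice you would likely wrap `grad_fn` in `jax.jit` and use an optimizer library instead of the hand-rolled update.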

This is why JAX is so awesome. Differentiable programming means you only have to think about problems in terms of the forward pass and then you can trivially get the derivative of that function without having to worry about the implementation details.


I hadn't heard of JAX before, but I've been tinkering in PyTorch. Would I also be able to switch the np arrays here to torch tensors, then call `.backward()` and get roughly the same benefits as JAX? Or how does it differ in this regard?



