package ws

import (
	"errors"
	"fmt"
	"log"
	"net/http"
	"strings"
	"time"

	mw "github.com/flatrender/gateway/internal/middleware"
	"github.com/gin-gonic/gin"
	"github.com/golang-jwt/jwt/v5"
	"github.com/google/uuid"
	"github.com/gorilla/websocket"
)

// proxyCloseRelayTimeout bounds how long relaying a close frame from one side
// of the proxy to the other may block.
const proxyCloseRelayTimeout = time.Second

// upgrader upgrades client connections and advertises the "flatrender.v1"
// subprotocol.
//
// NOTE(review): CheckOrigin accepts every Origin, so the browser same-origin
// check is disabled for this endpoint. A valid JWT is still required, which
// limits exposure, but consider restricting origins to the known frontends.
var upgrader = websocket.Upgrader{
	CheckOrigin:  func(r *http.Request) bool { return true },
	Subprotocols: []string{"flatrender.v1"},
}

// RenderProgressProxy returns a handler that relays a client's render-progress
// WebSocket to the render service's WebSocket endpoint.
//
// Connection: wss://gateway/ws/v1/render/{job_id}?token={jwt}
//
// The handler:
//  1. rejects a job_id that is not a UUID with 400 {"code":"bad_request"};
//  2. reads the JWT from the "token" query parameter, falling back to an
//     "Authorization: Bearer" header, verifies it as HMAC-signed with
//     jwtSecret, and requires its "sub" claim to be a UUID; otherwise it
//     replies 401 {"code":"unauthorized"};
//  3. upgrades the client connection and dials
//     {renderUpstreamWS}/ws/v1/render/{job_id}?user_id={sub}, forwarding the
//     client's token as a Bearer header. Job ownership is not checked here;
//     presumably the render service enforces it via user_id — confirm;
//  4. if the upstream dial fails, sends an UPSTREAM_UNAVAILABLE error message
//     and a 1011 close frame telling the client to use REST polling instead;
//  5. otherwise copies messages in both directions until either side fails or
//     closes, relaying a peer's close code and reason to the other side.
//
// renderUpstreamWS is the base WebSocket URL (ws:// or wss://) of the render
// service; a trailing slash is tolerated.
func RenderProgressProxy(renderUpstreamWS string, jwtSecret string) gin.HandlerFunc {
	upstreamBase := strings.TrimSuffix(renderUpstreamWS, "/")

	return func(c *gin.Context) {
		// Keep the parsed value: uuid.Parse also accepts "urn:uuid:" and braced
		// forms, so only the canonical String() form is safe to put in the
		// upstream path.
		jobID, err := uuid.Parse(c.Param("job_id"))
		if err != nil {
			c.JSON(http.StatusBadRequest, gin.H{"code": "bad_request", "message": "invalid job_id"})
			return
		}

		tokenStr := proxyBearerToken(c)
		if tokenStr == "" {
			c.JSON(http.StatusUnauthorized, gin.H{"code": "unauthorized", "message": "missing token"})
			return
		}
		userID, err := proxySubject(tokenStr, jwtSecret)
		if err != nil {
			c.JSON(http.StatusUnauthorized, gin.H{"code": "unauthorized", "message": "invalid token"})
			return
		}

		clientConn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
		if err != nil {
			// Upgrade has already replied to the client with an HTTP error.
			log.Printf("ws upgrade error: %v", err)
			return
		}
		defer clientConn.Close()

		upstreamURL := fmt.Sprintf("%s/ws/v1/render/%s?user_id=%s", upstreamBase, jobID, userID)
		upstreamConn, _, err := websocket.DefaultDialer.Dial(upstreamURL, http.Header{
			"Authorization": []string{"Bearer " + tokenStr},
		})
		if err != nil {
			log.Printf("ws upstream dial for job %s: %v", jobID, err)
			// Tell the client why, then close; both writes are best effort
			// because the client may already be gone.
			_ = clientConn.WriteJSON(gin.H{
				"type":    "error",
				"code":    "UPSTREAM_UNAVAILABLE",
				"message": "render service WebSocket unavailable; use REST polling fallback",
			})
			_ = clientConn.WriteMessage(websocket.CloseMessage,
				websocket.FormatCloseMessage(websocket.CloseInternalServerErr, "upstream unavailable"))
			return
		}
		defer upstreamConn.Close()

		// Each pump sends exactly once on errCh; capacity 2 lets the second
		// pump finish after this handler has stopped receiving.
		errCh := make(chan error, 2)
		go proxyPump(upstreamConn, clientConn, errCh) // client -> upstream
		go proxyPump(clientConn, upstreamConn, errCh) // upstream -> client

		// Returning runs the deferred Closes, which makes the other pump's
		// ReadMessage fail so it exits too.
		<-errCh
	}
}

// proxyBearerToken returns the JWT supplied with the request: the "token"
// query parameter when non-empty, otherwise the credentials of an
// "Authorization: Bearer <token>" header (scheme matched case-insensitively,
// as RFC 7235 requires). It returns "" when neither is present.
//
// The query parameter is presumably for browser clients, which cannot set
// handshake headers; be aware that it may be captured in access logs.
func proxyBearerToken(c *gin.Context) string {
	if tok := c.Query("token"); tok != "" {
		return tok
	}
	const scheme = "Bearer "
	hdr := c.GetHeader("Authorization")
	if len(hdr) > len(scheme) && strings.EqualFold(hdr[:len(scheme)], scheme) {
		return hdr[len(scheme):]
	}
	return ""
}

// proxySubject verifies tokenStr as an HMAC-signed JWT keyed by secret and
// returns its "sub" claim parsed as a UUID. It returns an error when parsing
// or verification by jwt.Parse fails, when the token uses a non-HMAC signing
// method, or when "sub" is missing or not a UUID.
func proxySubject(tokenStr, secret string) (uuid.UUID, error) {
	token, err := jwt.Parse(tokenStr, func(t *jwt.Token) (any, error) {
		// Refuse non-HMAC algorithms so a token cannot pick its own
		// verification scheme.
		if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
			return nil, jwt.ErrSignatureInvalid
		}
		return []byte(secret), nil
	})
	if err != nil {
		return uuid.Nil, err
	}
	if !token.Valid {
		return uuid.Nil, errors.New("token is not valid")
	}
	sub, err := token.Claims.GetSubject()
	if err != nil {
		return uuid.Nil, fmt.Errorf("reading sub claim: %w", err)
	}
	return uuid.Parse(sub)
}

// proxyPump copies messages from src to dst until a read or write fails, then
// sends the terminating error on done exactly once. When the read failed
// because src's peer sent a close frame, that close status is relayed to dst
// first, so dst's peer sees a clean close instead of a dropped connection.
func proxyPump(dst, src *websocket.Conn, done chan<- error) {
	for {
		mt, msg, err := src.ReadMessage()
		if err != nil {
			// Relay before signalling done: the handler closes both
			// connections as soon as it receives from done.
			proxyRelayClose(dst, err)
			done <- err
			return
		}
		if err := dst.WriteMessage(mt, msg); err != nil {
			done <- err
			return
		}
	}
}

// proxyRelayClose forwards the close code and reason carried by err, when err
// is a *websocket.CloseError, to dst. Codes reserved by RFC 6455 for local
// reporting only (1006 abnormal closure, 1015 TLS handshake) are not sent.
// Best effort: dst may already be closed, so the write error is ignored.
func proxyRelayClose(dst *websocket.Conn, err error) {
	var ce *websocket.CloseError
	if !errors.As(err, &ce) {
		return
	}
	if ce.Code == websocket.CloseAbnormalClosure || ce.Code == websocket.CloseTLSHandshake {
		return
	}
	msg := websocket.FormatCloseMessage(ce.Code, ce.Text)
	_ = dst.WriteControl(websocket.CloseMessage, msg, time.Now().Add(proxyCloseRelayTimeout))
}

// mw is not otherwise referenced in this file; this blank reference only keeps
// the import compiling. TODO: remove both, or (if mw.CtxUserID is the key under
// which the auth middleware stores the user ID — verify) read the user ID from
// the gin context instead of re-parsing the JWT here.
var _ = mw.CtxUserID