diff --git a/pro/controllers/jit.go b/pro/controllers/jit.go index 53577dd4..669b980f 100644 --- a/pro/controllers/jit.go +++ b/pro/controllers/jit.go @@ -4,6 +4,7 @@ import ( "encoding/json" "errors" "net/http" + "strconv" "time" "github.com/gorilla/mux" @@ -85,15 +86,42 @@ func handleJIT(w http.ResponseWriter, r *http.Request) { // handleJITGet - handles GET requests for JIT status/requests func handleJITGet(w http.ResponseWriter, r *http.Request, networkID string, user *models.User) { - statusFilter := r.URL.Query().Get("status") // "pending", "approved", "denied", "expired", or empty for all - requests, err := proLogic.GetNetworkJITRequests(networkID, statusFilter) + + // Parse pagination parameters (default to 0, db.SetPagination will apply defaults) + page, _ := strconv.Atoi(r.URL.Query().Get("page")) + pageSize, _ := strconv.Atoi(r.URL.Query().Get("per_page")) + + // Apply defaults if not provided (matching db.SetPagination logic) + if page < 1 { + page = 1 + } + if pageSize < 1 || pageSize > 100 { + pageSize = 10 + } + + ctx := db.WithContext(r.Context()) + requests, total, err := proLogic.GetNetworkJITRequestsPaginated(ctx, networkID, statusFilter, page, pageSize) if err != nil { logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } - logic.ReturnSuccessResponseWithJson(w, r, requests, "fetched JIT requests") + // Calculate pagination metadata + totalPages := (int(total) + pageSize - 1) / pageSize + if totalPages == 0 { + totalPages = 1 + } + + response := map[string]interface{}{ + "data": requests, + "page": page, + "per_page": pageSize, + "total": total, + "total_pages": totalPages, + } + + logic.ReturnSuccessResponseWithJson(w, r, response, "fetched JIT requests") } // handleJITPost - handles POST requests for JIT operations diff --git a/pro/logic/jit.go b/pro/logic/jit.go index f0902845..18fb3559 100644 --- a/pro/logic/jit.go +++ b/pro/logic/jit.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "sort" "time" "github.com/google/uuid" @@ -309,37 +310,82 @@ type JITRequestWithGrant struct { // statusFilter can be: "pending", "approved", "denied", "expired", or "" for all func GetNetworkJITRequests(networkID string, statusFilter string) ([]JITRequestWithGrant, error) { ctx := db.WithContext(context.Background()) + requests, _, err := GetNetworkJITRequestsPaginated(ctx, networkID, statusFilter, 1, 0) + return requests, err +} +// GetNetworkJITRequestsPaginated - gets paginated JIT requests for a network, optionally filtered by status +// statusFilter can be: "pending", "approved", "denied", "expired", or "" for all +// page and pageSize control pagination. db.SetPagination will apply defaults (page=1, pageSize=10) if values are invalid. +// Returns: requests, total count, error +func GetNetworkJITRequestsPaginated(ctx context.Context, networkID string, statusFilter string, page, pageSize int) ([]JITRequestWithGrant, int64, error) { request := schema.JITRequest{NetworkID: networkID} var requests []schema.JITRequest + var total int64 var err error - // If no filter, return all requests + // Always set up pagination context - db.SetPagination handles defaults (page=1, pageSize=10) + paginatedCtx := db.SetPagination(ctx, page, pageSize) + + // Get total count for pagination metadata if statusFilter == "" || statusFilter == "all" { - requests, err = request.ListByNetwork(ctx) + total, err = request.CountByNetwork(ctx) if err != nil { - return nil, err + return nil, 0, err + } + requests, err = request.ListByNetwork(paginatedCtx) + if err != nil { + return nil, 0, err } } else if statusFilter == "expired" { // Handle expired filter (approved requests that have expired) + // For expired filter, we need to get all and filter in memory, then apply pagination allRequests, err := request.ListByNetwork(ctx) if err != nil { - return nil, err + return nil, 0, err } now := time.Now().UTC() + var filteredRequests []schema.JITRequest for _, req := range allRequests { // Include requests with status "expired" or "approved" requests that have passed expiration if req.Status == "expired" || (req.Status == "approved" && !req.ExpiresAt.IsZero() && now.After(req.ExpiresAt)) { - requests = append(requests, req) + filteredRequests = append(filteredRequests, req) } } + + // Sort by requested_at DESC (most recent first) + sort.Slice(filteredRequests, func(i, j int) bool { + return filteredRequests[i].RequestedAt.After(filteredRequests[j].RequestedAt) + }) + + total = int64(len(filteredRequests)) + + // Apply pagination manually for expired filter + if pageSize > 0 { + offset := (page - 1) * pageSize + end := offset + pageSize + if offset >= len(filteredRequests) { + requests = []schema.JITRequest{} + } else { + if end > len(filteredRequests) { + end = len(filteredRequests) + } + requests = filteredRequests[offset:end] + } + } else { + requests = filteredRequests + } } else { // Filter by status: pending, approved, or denied - requests, err = request.ListByStatusAndNetwork(ctx, statusFilter) + total, err = request.CountByStatusAndNetwork(ctx, statusFilter) if err != nil { - return nil, err + return nil, 0, err + } + requests, err = request.ListByStatusAndNetwork(paginatedCtx, statusFilter) + if err != nil { + return nil, 0, err } } @@ -361,7 +407,7 @@ func GetNetworkJITRequests(networkID string, statusFilter string) ([]JITRequestW result = append(result, enriched) } - return result, nil + return result, total, nil } // GetUserJITStatus - gets JIT status for a user on a network diff --git a/schema/jit_request.go b/schema/jit_request.go index b0f66476..d3e88e82 100644 --- a/schema/jit_request.go +++ b/schema/jit_request.go @@ -46,7 +46,7 @@ func (r *JITRequest) Delete(ctx context.Context) error { func (r *JITRequest) ListByNetwork(ctx context.Context) ([]JITRequest, error) { var requests []JITRequest - err := db.FromContext(ctx).Table(r.Table()).Where("network_id = ?", r.NetworkID).Find(&requests).Error + err := db.FromContext(ctx).Table(r.Table()).Where("network_id = ?", r.NetworkID).Order("requested_at DESC").Find(&requests).Error return requests, err } @@ -64,6 +64,18 @@ func (r *JITRequest) ListPendingByNetwork(ctx context.Context) ([]JITRequest, er func (r *JITRequest) ListByStatusAndNetwork(ctx context.Context, status string) ([]JITRequest, error) { var requests []JITRequest - err := db.FromContext(ctx).Table(r.Table()).Where("network_id = ? AND status = ?", r.NetworkID, status).Find(&requests).Error + err := db.FromContext(ctx).Table(r.Table()).Where("network_id = ? AND status = ?", r.NetworkID, status).Order("requested_at DESC").Find(&requests).Error return requests, err } + +func (r *JITRequest) CountByNetwork(ctx context.Context) (int64, error) { + var count int64 + err := db.FromContext(ctx).Table(r.Table()).Where("network_id = ?", r.NetworkID).Count(&count).Error + return count, err +} + +func (r *JITRequest) CountByStatusAndNetwork(ctx context.Context, status string) (int64, error) { + var count int64 + err := db.FromContext(ctx).Table(r.Table()).Where("network_id = ? AND status = ?", r.NetworkID, status).Count(&count).Error + return count, err +}