diff --git a/.gitignore b/.gitignore index 03a21e9..8ee567c 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ .env rate hwserver +tencent_ocr \ No newline at end of file diff --git a/.history/.gitignore_20250115154807 b/.history/.gitignore_20250115154807 new file mode 100644 index 0000000..03a21e9 --- /dev/null +++ b/.history/.gitignore_20250115154807 @@ -0,0 +1,3 @@ +.env +rate +hwserver diff --git a/.history/.gitignore_20250115163823 b/.history/.gitignore_20250115163823 new file mode 100644 index 0000000..8ee567c --- /dev/null +++ b/.history/.gitignore_20250115163823 @@ -0,0 +1,4 @@ +.env +rate +hwserver +tencent_ocr \ No newline at end of file diff --git a/.history/cmd/main_20250115161111.go b/.history/cmd/main_20250115161111.go new file mode 100644 index 0000000..80a7a0d --- /dev/null +++ b/.history/cmd/main_20250115161111.go @@ -0,0 +1,66 @@ +package main + +import ( + "log" + + "github.com/gin-gonic/gin" + "tencent_ocr/pkg/config" + "tencent_ocr/pkg/handler" +) + +func main() { + // Load configuration + cfg, err := config.LoadConfig() + if err != nil { + log.Fatalf("Failed to load configuration: %v", err) + } + // Initialize services + geminiService, err := service.NewGeminiService(cfg.GeminiAPIKey) + if err != nil { + log.Fatal(err) + } + defer geminiService.Close() + + ocrService := handler.NewOCRService( + cfg.TencentSecretID, + cfg.TencentSecretKey, + geminiService, + ) + + // Initialize handlers + ocrHandler := handler.NewOCRHandler( + cfg.TencentSecretID, + cfg.TencentSecretKey, + cfg.GeminiAPIKey, + cfg.APIKey, + ) + + rateHandler := handler.NewRateHandler( + cfg.GeminiAPIKey, + cfg.APIKey, + ) + + uploadHandler := handler.NewUploadHandler( + cfg.R2AccessKey, + cfg.R2SecretKey, + cfg.R2Bucket, + cfg.R2Endpoint, + cfg.R2CustomDomain, + ocrService, + geminiService, + ) + + // Setup Gin router + r := gin.Default() + + // Register routes + r.POST("/ocr", ocrHandler.HandleOCR) + r.POST("/rate", rateHandler.HandleRate) + // upload file to server + r.POST("/upload", uploadHandler.HandleUpload) + + // Start server + if err := r.Run("localhost:8080"); err != nil { + log.Fatalf("Failed to start server: %v", err) + } +} \ No newline at end of file diff --git a/.history/cmd/main_20250115161117.go b/.history/cmd/main_20250115161117.go new file mode 100644 index 0000000..fb33427 --- /dev/null +++ b/.history/cmd/main_20250115161117.go @@ -0,0 +1,67 @@ +package main + +import ( + "log" + + "github.com/gin-gonic/gin" + "tencent_ocr/pkg/config" + "tencent_ocr/pkg/handler" + "tencent_ocr/pkg/service" +) + +func main() { + // Load configuration + cfg, err := config.LoadConfig() + if err != nil { + log.Fatalf("Failed to load configuration: %v", err) + } + // Initialize services + geminiService, err := service.NewGeminiService(cfg.GeminiAPIKey) + if err != nil { + log.Fatal(err) + } + defer geminiService.Close() + + ocrService := handler.NewOCRService( + cfg.TencentSecretID, + cfg.TencentSecretKey, + geminiService, + ) + + // Initialize handlers + ocrHandler := handler.NewOCRHandler( + cfg.TencentSecretID, + cfg.TencentSecretKey, + cfg.GeminiAPIKey, + cfg.APIKey, + ) + + rateHandler := handler.NewRateHandler( + cfg.GeminiAPIKey, + cfg.APIKey, + ) + + uploadHandler := handler.NewUploadHandler( + cfg.R2AccessKey, + cfg.R2SecretKey, + cfg.R2Bucket, + cfg.R2Endpoint, + cfg.R2CustomDomain, + ocrService, + geminiService, + ) + + // Setup Gin router + r := gin.Default() + + // Register routes + r.POST("/ocr", ocrHandler.HandleOCR) + r.POST("/rate", rateHandler.HandleRate) + // upload file to server + r.POST("/upload", uploadHandler.HandleUpload) + + // Start server + if err := r.Run("localhost:8080"); err != nil { + log.Fatalf("Failed to start server: %v", err) + } +} \ No newline at end of file diff --git a/.history/cmd/main_20250115162857.go b/.history/cmd/main_20250115162857.go new file mode 100644 index 0000000..a0f50fc --- /dev/null +++ b/.history/cmd/main_20250115162857.go @@ -0,0 +1,95 @@ +package main + +import ( + "context" + "log" + "net/http" + "os" + "os/signal" + "syscall" + "time" + + "github.com/gin-gonic/gin" + "golang.org/x/time/rate" + + "tencent_ocr/pkg/config" + "tencent_ocr/pkg/handler" + "tencent_ocr/pkg/middleware" + "tencent_ocr/pkg/service" +) + +func main() { + // Load configuration + cfg, err := config.LoadConfig() + if err != nil { + log.Fatalf("Failed to load configuration: %v", err) + } + + // Initialize services + ocrService, err := service.NewOCRService(cfg.TencentSecretID, cfg.TencentSecretKey) + if err != nil { + log.Fatalf("Failed to initialize OCR service: %v", err) + } + defer ocrService.Close() + + geminiService, err := service.NewGeminiService(cfg.GeminiAPIKey) + if err != nil { + log.Fatalf("Failed to initialize Gemini service: %v", err) + } + defer geminiService.Close() + + uploadService, err := service.NewUploadService( + cfg.R2AccessKey, + cfg.R2SecretKey, + cfg.R2Bucket, + cfg.R2Endpoint, + cfg.R2CustomDomain, + ) + if err != nil { + log.Fatalf("Failed to initialize upload service: %v", err) + } + defer uploadService.Close() + + // Initialize handlers + ocrHandler := handler.NewOCRHandler(ocrService, geminiService) + rateHandler := handler.NewRateHandler(geminiService) + uploadHandler := handler.NewUploadHandler(uploadService, ocrService, geminiService) + + // Setup Gin router + r := gin.Default() + + // Add middleware + r.Use(middleware.APIKeyAuth(cfg.APIKey)) + r.Use(middleware.RateLimit(rate.Limit(10), 20)) // 10 requests per second with burst of 20 + + // Register routes + r.POST("/ocr", ocrHandler.HandleOCR) + r.POST("/rate", rateHandler.HandleRate) + r.POST("/upload", uploadHandler.HandleUpload) + + // Create server with graceful shutdown + srv := &http.Server{ + Addr: "localhost:8080", + Handler: r, + } + + // Graceful shutdown + go func() { + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) + <-quit + log.Println("Shutting down server...") + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := srv.Shutdown(ctx); err != nil { + log.Fatal("Server forced to shutdown:", err) + } + }() + + log.Println("Server starting on localhost:8080") + if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { + log.Fatalf("Failed to start server: %v", err) + } +} \ No newline at end of file diff --git a/.history/cmd/main_20250115162917.go b/.history/cmd/main_20250115162917.go new file mode 100644 index 0000000..a0f50fc --- /dev/null +++ b/.history/cmd/main_20250115162917.go @@ -0,0 +1,95 @@ +package main + +import ( + "context" + "log" + "net/http" + "os" + "os/signal" + "syscall" + "time" + + "github.com/gin-gonic/gin" + "golang.org/x/time/rate" + + "tencent_ocr/pkg/config" + "tencent_ocr/pkg/handler" + "tencent_ocr/pkg/middleware" + "tencent_ocr/pkg/service" +) + +func main() { + // Load configuration + cfg, err := config.LoadConfig() + if err != nil { + log.Fatalf("Failed to load configuration: %v", err) + } + + // Initialize services + ocrService, err := service.NewOCRService(cfg.TencentSecretID, cfg.TencentSecretKey) + if err != nil { + log.Fatalf("Failed to initialize OCR service: %v", err) + } + defer ocrService.Close() + + geminiService, err := service.NewGeminiService(cfg.GeminiAPIKey) + if err != nil { + log.Fatalf("Failed to initialize Gemini service: %v", err) + } + defer geminiService.Close() + + uploadService, err := service.NewUploadService( + cfg.R2AccessKey, + cfg.R2SecretKey, + cfg.R2Bucket, + cfg.R2Endpoint, + cfg.R2CustomDomain, + ) + if err != nil { + log.Fatalf("Failed to initialize upload service: %v", err) + } + defer uploadService.Close() + + // Initialize handlers + ocrHandler := handler.NewOCRHandler(ocrService, geminiService) + rateHandler := handler.NewRateHandler(geminiService) + uploadHandler := handler.NewUploadHandler(uploadService, ocrService, geminiService) + + // Setup Gin router + r := gin.Default() + + // Add middleware + r.Use(middleware.APIKeyAuth(cfg.APIKey)) + r.Use(middleware.RateLimit(rate.Limit(10), 20)) // 10 requests per second with burst of 20 + + // Register routes + r.POST("/ocr", ocrHandler.HandleOCR) + r.POST("/rate", rateHandler.HandleRate) + r.POST("/upload", uploadHandler.HandleUpload) + + // Create server with graceful shutdown + srv := &http.Server{ + Addr: "localhost:8080", + Handler: r, + } + + // Graceful shutdown + go func() { + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) + <-quit + log.Println("Shutting down server...") + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := srv.Shutdown(ctx); err != nil { + log.Fatal("Server forced to shutdown:", err) + } + }() + + log.Println("Server starting on localhost:8080") + if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { + log.Fatalf("Failed to start server: %v", err) + } +} \ No newline at end of file diff --git a/cmd/server/main.go b/.history/cmd/server/main_20250115141957.go similarity index 100% rename from cmd/server/main.go rename to .history/cmd/server/main_20250115141957.go diff --git a/.history/cmd/server/main_20250115161032.go b/.history/cmd/server/main_20250115161032.go new file mode 100644 index 0000000..5d5bee2 --- /dev/null +++ b/.history/cmd/server/main_20250115161032.go @@ -0,0 +1,52 @@ +package main + +import ( + "log" + + "github.com/gin-gonic/gin" + "tencent_ocr/pkg/config" + "tencenthw/pkg/handler" +) + +func main() { + // Load configuration + cfg, err := config.LoadConfig() + if err != nil { + log.Fatalf("Failed to load configuration: %v", err) + } + + // Initialize handlers + ocrHandler := handler.NewOCRHandler( + cfg.TencentSecretID, + cfg.TencentSecretKey, + cfg.GeminiAPIKey, + cfg.APIKey, + ) + + rateHandler := handler.NewRateHandler( + cfg.GeminiAPIKey, + cfg.APIKey, + ) + + uploadHandler := handler.NewUploadHandler( + cfg.R2AccessKey, + cfg.R2SecretKey, + cfg.R2Bucket, + cfg.R2Endpoint, + cfg.R2CustomDomain, + ) + + // Setup Gin router + r := gin.Default() + + // Register routes + r.POST("/ocr", ocrHandler.HandleOCR) + r.POST("/rate", rateHandler.HandleRate) + // upload file to server + r.POST("/upload", uploadHandler.HandleUpload) + + // Start server + if err := r.Run("localhost:8080"); err != nil { + log.Fatalf("Failed to start server: %v", err) + } +} \ No newline at end of file diff --git a/.history/cmd/server/main_20250115161156.go b/.history/cmd/server/main_20250115161156.go new file mode 100644 index 0000000..5d5bee2 --- /dev/null +++ b/.history/cmd/server/main_20250115161156.go @@ -0,0 +1,52 @@ +package main + +import ( + "log" + + "github.com/gin-gonic/gin" + "tencent_ocr/pkg/config" + "tencenthw/pkg/handler" +) + +func main() { + // Load configuration + cfg, err := config.LoadConfig() + if err != nil { + log.Fatalf("Failed to load configuration: %v", err) + } + + // Initialize handlers + ocrHandler := handler.NewOCRHandler( + cfg.TencentSecretID, + cfg.TencentSecretKey, + cfg.GeminiAPIKey, + cfg.APIKey, + ) + + rateHandler := handler.NewRateHandler( + cfg.GeminiAPIKey, + cfg.APIKey, + ) + + uploadHandler := handler.NewUploadHandler( + cfg.R2AccessKey, + cfg.R2SecretKey, + cfg.R2Bucket, + cfg.R2Endpoint, + cfg.R2CustomDomain, + ) + + // Setup Gin router + r := gin.Default() + + // Register routes + r.POST("/ocr", ocrHandler.HandleOCR) + r.POST("/rate", rateHandler.HandleRate) + // upload file to server + r.POST("/upload", uploadHandler.HandleUpload) + + // Start server + if err := r.Run("localhost:8080"); err != nil { + log.Fatalf("Failed to start server: %v", err) + } +} \ No newline at end of file diff --git a/.history/cmd/server/main_20250115161351.go b/.history/cmd/server/main_20250115161351.go new file mode 100644 index 0000000..6c068b3 --- /dev/null +++ b/.history/cmd/server/main_20250115161351.go @@ -0,0 +1,52 @@ +package main + +import ( + "log" + + "github.com/gin-gonic/gin" + "tencent_ocr/pkg/config" + "tencent_ocr/pkg/handler" +) + +func main() { + // Load configuration + cfg, err := config.LoadConfig() + if err != nil { + log.Fatalf("Failed to load configuration: %v", err) + } + + // Initialize handlers + ocrHandler := handler.NewOCRHandler( + cfg.TencentSecretID, + cfg.TencentSecretKey, + cfg.GeminiAPIKey, + cfg.APIKey, + ) + + rateHandler := handler.NewRateHandler( + cfg.GeminiAPIKey, + cfg.APIKey, + ) + + uploadHandler := handler.NewUploadHandler( + cfg.R2AccessKey, + cfg.R2SecretKey, + cfg.R2Bucket, + cfg.R2Endpoint, + cfg.R2CustomDomain, + ) + + // Setup Gin router + r := gin.Default() + + // Register routes + r.POST("/ocr", ocrHandler.HandleOCR) + r.POST("/rate", rateHandler.HandleRate) + // upload file to server + r.POST("/upload", uploadHandler.HandleUpload) + + // Start server + if err := r.Run("localhost:8080"); err != nil { + log.Fatalf("Failed to start server: %v", err) + } +} \ No newline at end of file diff --git a/.history/pkg/errors/errors_20250115162803.go b/.history/pkg/errors/errors_20250115162803.go new file mode 100644 index 0000000..0519ecb --- /dev/null +++ b/.history/pkg/errors/errors_20250115162803.go @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/.history/pkg/errors/errors_20250115162807.go b/.history/pkg/errors/errors_20250115162807.go new file mode 100644 index 0000000..2190bd5 --- /dev/null +++ b/.history/pkg/errors/errors_20250115162807.go @@ -0,0 +1,82 @@ +package errors + +import ( + "fmt" +) + +type Error struct { + Message string + Cause error + Type ErrorType +} + +type ErrorType int + +const ( + ErrorTypeUnknown ErrorType = iota + ErrorTypeClient + ErrorTypeServer +) + +func (e *Error) Error() string { + if e.Cause != nil { + return fmt.Sprintf("%s: %v", e.Message, e.Cause) + } + return e.Message +} + +func (e *Error) Unwrap() error { + return e.Cause +} + +func New(message string) error { + return &Error{ + Message: message, + Type: ErrorTypeUnknown, + } +} + +func Wrap(err error, message string) error { + if err == nil { + return nil + } + return &Error{ + Message: message, + Cause: err, + Type: ErrorTypeUnknown, + } +} + +func NewClientError(message string) error { + return &Error{ + Message: message, + Type: ErrorTypeClient, + } +} + +func NewServerError(message string) error { + return &Error{ + Message: message, + Type: ErrorTypeServer, + } +} + +func IsClientError(err error) bool { + if err == nil { + return false + } + if e, ok := err.(*Error); ok { + return e.Type == ErrorTypeClient + } + return false +} + +func IsServerError(err error) bool { + if err == nil { + return false + } + if e, ok := err.(*Error); ok { + return e.Type == ErrorTypeServer + } + return false +} \ No newline at end of file diff --git a/.history/pkg/errors/errors_20250115162825.go b/.history/pkg/errors/errors_20250115162825.go new file mode 100644 index 0000000..2190bd5 --- /dev/null +++ b/.history/pkg/errors/errors_20250115162825.go @@ -0,0 +1,82 @@ +package errors + +import ( + "fmt" +) + +type Error struct { + Message string + Cause error + Type ErrorType +} + +type ErrorType int + +const ( + ErrorTypeUnknown ErrorType = iota + ErrorTypeClient + ErrorTypeServer +) + +func (e *Error) Error() string { + if e.Cause != nil { + return fmt.Sprintf("%s: %v", e.Message, e.Cause) + } + return e.Message +} + +func (e *Error) Unwrap() error { + return e.Cause +} + +func New(message string) error { + return &Error{ + Message: message, + Type: ErrorTypeUnknown, + } +} + +func Wrap(err error, message string) error { + if err == nil { + return nil + } + return &Error{ + Message: message, + Cause: err, + Type: ErrorTypeUnknown, + } +} + +func NewClientError(message string) error { + return &Error{ + Message: message, + Type: ErrorTypeClient, + } +} + +func NewServerError(message string) error { + return &Error{ + Message: message, + Type: ErrorTypeServer, + } +} + +func IsClientError(err error) bool { + if err == nil { + return false + } + if e, ok := err.(*Error); ok { + return e.Type == ErrorTypeClient + } + return false +} + +func IsServerError(err error) bool { + if err == nil { + return false + } + if e, ok := err.(*Error); ok { + return e.Type == ErrorTypeServer + } + return false +} \ No newline at end of file diff --git a/.history/pkg/handler/ocr_20250115161506.go b/.history/pkg/handler/ocr_20250115161506.go new file mode 100644 index 0000000..1b47e1a --- /dev/null +++ b/.history/pkg/handler/ocr_20250115161506.go @@ -0,0 +1,127 @@ +package handler + +import ( + "context" + "encoding/base64" + "net/http" + "strings" + + "github.com/gin-gonic/gin" + "github.com/google/generative-ai-go/genai" + "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common" + "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/profile" + ocr "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/ocr/v20181119" + "google.golang.org/api/option" + "tencent_ocr/pkg/service" +) + +type OCRService struct { + tencentSecretID string + tencentSecretKey string + geminiService *service.GeminiService +} + +func NewOCRService(tencentSecretID, tencentSecretKey string, geminiService *service.GeminiService) *OCRService { + return &OCRService{ + tencentSecretID: tencentSecretID, + tencentSecretKey: tencentSecretKey, + geminiService: geminiService, + } +} + +func (s *OCRService) ProcessImage(ctx context.Context, imageBase64 string) (string, error) { + // Initialize Tencent Cloud client + credential := common.NewCredential(s.tencentSecretID, s.tencentSecretKey) + cpf := profile.NewClientProfile() + cpf.HttpProfile.Endpoint = "ocr.tencentcloudapi.com" + client, err := ocr.NewClient(credential, "", cpf) + if err != nil { + return "", err + } + + // Create OCR request + request := ocr.NewGeneralHandwritingOCRRequest() + request.ImageBase64 = common.StringPtr(imageBase64) + + // Perform OCR + response, err := client.GeneralHandwritingOCR(request) + if err != nil { + return "", err + } + + // Extract text from OCR response + var ocrText string + for _, textDetection := range response.Response.TextDetections { + ocrText += *textDetection.DetectedText + "\n" + } + + return ocrText, nil +} + +type OCRRequest struct { + ImageBase64 string `json:"image_base64"` + ImageURL string `json:"image_url"` + Scene string `json:"scene"` + APIKey string `json:"apikey" binding:"required"` +} + +type OCRResponse struct { + OriginalText string `json:"original_text"` + Result string `json:"result"` + Success bool `json:"success"` +} + +func (h *OCRService) HandleOCR(c *gin.Context) { + var req OCRRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, OCRResponse{ + Success: false, + Result: "Invalid request format", + }) + return + } + + // Validate API key + if req.APIKey != h.geminiService.APIKey { + c.JSON(http.StatusUnauthorized, OCRResponse{ + Success: false, + Result: "Invalid API key", + }) + return + } + + // Validate that at least one of ImageURL or ImageBase64 is provided + if req.ImageURL == "" && req.ImageBase64 == "" { + c.JSON(http.StatusBadRequest, OCRResponse{ + Success: false, + Result: "Either image_url or image_base64 must be provided", + }) + return + } + + // Process image + ocrText, err := h.ProcessImage(c.Request.Context(), req.ImageBase64) + if err != nil { + c.JSON(http.StatusInternalServerError, OCRResponse{ + Success: false, + Result: "OCR processing failed", + }) + return + } + + // Process with Gemini + processedText, err := h.geminiService.ProcessText(c.Request.Context(), ocrText) + if err != nil { + c.JSON(http.StatusInternalServerError, OCRResponse{ + Success: false, + Result: "Text processing failed", + }) + return + } + + c.JSON(http.StatusOK, OCRResponse{ + Success: true, + OriginalText: ocrText, + Result: processedText, + }) +} \ No newline at end of file diff --git a/.history/pkg/handler/ocr_20250115161810.go b/.history/pkg/handler/ocr_20250115161810.go new file mode 100644 index 0000000..7fb35cb --- /dev/null +++ b/.history/pkg/handler/ocr_20250115161810.go @@ -0,0 +1,125 @@ +package handler + +import ( + "context" + "net/http" + + "github.com/gin-gonic/gin" + "github.com/google/generative-ai-go/genai" + "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common" + "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/profile" + ocr "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/ocr/v20181119" + "google.golang.org/api/option" + "tencent_ocr/pkg/service" +) + +type OCRService struct { + tencentSecretID string + tencentSecretKey string + geminiService *service.GeminiService +} + +func NewOCRService(tencentSecretID, tencentSecretKey string, geminiService *service.GeminiService) *OCRService { + return &OCRService{ + tencentSecretID: tencentSecretID, + tencentSecretKey: tencentSecretKey, + geminiService: geminiService, + } +} + +func (s *OCRService) ProcessImage(ctx context.Context, imageBase64 string) (string, error) { + // Initialize Tencent Cloud client + credential := common.NewCredential(s.tencentSecretID, s.tencentSecretKey) + cpf := profile.NewClientProfile() + cpf.HttpProfile.Endpoint = "ocr.tencentcloudapi.com" + client, err := ocr.NewClient(credential, "", cpf) + if err != nil { + return "", err + } + + // Create OCR request + request := ocr.NewGeneralHandwritingOCRRequest() + request.ImageBase64 = common.StringPtr(imageBase64) + + // Perform OCR + response, err := client.GeneralHandwritingOCR(request) + if err != nil { + return "", err + } + + // Extract text from OCR response + var ocrText string + for _, textDetection := range response.Response.TextDetections { + ocrText += *textDetection.DetectedText + "\n" + } + + return ocrText, nil +} + +type OCRRequest struct { + ImageBase64 string `json:"image_base64"` + ImageURL string `json:"image_url"` + Scene string `json:"scene"` + APIKey string `json:"apikey" binding:"required"` +} + +type OCRResponse struct { + OriginalText string `json:"original_text"` + Result string `json:"result"` + Success bool `json:"success"` +} + +func (h *OCRService) HandleOCR(c *gin.Context) { + var req OCRRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, OCRResponse{ + Success: false, + Result: "Invalid request format", + }) + return + } + + // Validate API key + if req.APIKey != h.geminiService.APIKey { + c.JSON(http.StatusUnauthorized, OCRResponse{ + Success: false, + Result: "Invalid API key", + }) + return + } + + // Validate that at least one of ImageURL or ImageBase64 is provided + if req.ImageURL == "" && req.ImageBase64 == "" { + c.JSON(http.StatusBadRequest, OCRResponse{ + Success: false, + Result: "Either image_url or image_base64 must be provided", + }) + return + } + + // Process image + ocrText, err := h.ProcessImage(c.Request.Context(), req.ImageBase64) + if err != nil { + c.JSON(http.StatusInternalServerError, OCRResponse{ + Success: false, + Result: "OCR processing failed", + }) + return + } + + // Process with Gemini + processedText, err := h.geminiService.ProcessText(c.Request.Context(), ocrText) + if err != nil { + c.JSON(http.StatusInternalServerError, OCRResponse{ + Success: false, + Result: "Text processing failed", + }) + return + } + + c.JSON(http.StatusOK, OCRResponse{ + Success: true, + OriginalText: ocrText, + Result: processedText, + }) +} \ No newline at end of file diff --git a/.history/pkg/handler/ocr_20250115161816.go b/.history/pkg/handler/ocr_20250115161816.go new file mode 100644 index 0000000..23c00d0 --- /dev/null +++ b/.history/pkg/handler/ocr_20250115161816.go @@ -0,0 +1,124 @@ +package handler + +import ( + "context" + "net/http" + + "github.com/gin-gonic/gin" + "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common" + "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/profile" + ocr "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/ocr/v20181119" + "google.golang.org/api/option" + "tencent_ocr/pkg/service" +) + +type OCRService struct { + tencentSecretID string + tencentSecretKey string + geminiService *service.GeminiService +} + +func NewOCRService(tencentSecretID, tencentSecretKey string, geminiService *service.GeminiService) *OCRService { + return &OCRService{ + tencentSecretID: tencentSecretID, + tencentSecretKey: tencentSecretKey, + geminiService: geminiService, + } +} + +func (s *OCRService) ProcessImage(ctx context.Context, imageBase64 string) (string, error) { + // Initialize Tencent Cloud client + credential := common.NewCredential(s.tencentSecretID, s.tencentSecretKey) + cpf := profile.NewClientProfile() + cpf.HttpProfile.Endpoint = "ocr.tencentcloudapi.com" + client, err := ocr.NewClient(credential, "", cpf) + if err != nil { + return "", err + } + + // Create OCR request + request := ocr.NewGeneralHandwritingOCRRequest() + request.ImageBase64 = common.StringPtr(imageBase64) + + // Perform OCR + response, err := client.GeneralHandwritingOCR(request) + if err != nil { + return "", err + } + + // Extract text from OCR response + var ocrText string + for _, textDetection := range response.Response.TextDetections { + ocrText += *textDetection.DetectedText + "\n" + } + + return ocrText, nil +} + +type OCRRequest struct { + ImageBase64 string `json:"image_base64"` + ImageURL string `json:"image_url"` + Scene string `json:"scene"` + APIKey string `json:"apikey" binding:"required"` +} + +type OCRResponse struct { + OriginalText string `json:"original_text"` + Result string `json:"result"` + Success bool `json:"success"` +} + +func (h *OCRService) HandleOCR(c *gin.Context) { + var req OCRRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, OCRResponse{ + Success: false, + Result: "Invalid request format", + }) + return + } + + // Validate API key + if req.APIKey != h.geminiService.APIKey { + c.JSON(http.StatusUnauthorized, OCRResponse{ + Success: false, + Result: "Invalid API key", + }) + return + } + + // Validate that at least one of ImageURL or ImageBase64 is provided + if req.ImageURL == "" && req.ImageBase64 == "" { + c.JSON(http.StatusBadRequest, OCRResponse{ + Success: false, + Result: "Either image_url or image_base64 must be provided", + }) + return + } + + // Process image + ocrText, err := h.ProcessImage(c.Request.Context(), req.ImageBase64) + if err != nil { + c.JSON(http.StatusInternalServerError, OCRResponse{ + Success: false, + Result: "OCR processing failed", + }) + return + } + + // Process with Gemini + processedText, err := h.geminiService.ProcessText(c.Request.Context(), ocrText) + if err != nil { + c.JSON(http.StatusInternalServerError, OCRResponse{ + Success: false, + Result: "Text processing failed", + }) + return + } + + c.JSON(http.StatusOK, OCRResponse{ + Success: true, + OriginalText: ocrText, + Result: processedText, + }) +} \ No newline at end of file diff --git a/.history/pkg/handler/ocr_20250115161825.go b/.history/pkg/handler/ocr_20250115161825.go new file mode 100644 index 0000000..620e573 --- /dev/null +++ b/.history/pkg/handler/ocr_20250115161825.go @@ -0,0 +1,123 @@ +package handler + +import ( + "context" + "net/http" + + "github.com/gin-gonic/gin" + "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common" + "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/profile" + ocr "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/ocr/v20181119" + "tencent_ocr/pkg/service" +) + +type OCRService struct { + tencentSecretID string + tencentSecretKey string + geminiService *service.GeminiService +} + +func NewOCRService(tencentSecretID, tencentSecretKey string, geminiService *service.GeminiService) *OCRService { + return &OCRService{ + tencentSecretID: tencentSecretID, + tencentSecretKey: tencentSecretKey, + geminiService: geminiService, + } +} + +func (s *OCRService) ProcessImage(ctx context.Context, imageBase64 string) (string, error) { + // Initialize Tencent Cloud client + credential := common.NewCredential(s.tencentSecretID, s.tencentSecretKey) + cpf := profile.NewClientProfile() + cpf.HttpProfile.Endpoint = "ocr.tencentcloudapi.com" + client, err := ocr.NewClient(credential, "", cpf) + if err != nil { + return "", err + } + + // Create OCR request + request := ocr.NewGeneralHandwritingOCRRequest() + request.ImageBase64 = common.StringPtr(imageBase64) + + // Perform OCR + response, err := client.GeneralHandwritingOCR(request) + if err != nil { + return "", err + } + + // Extract text from OCR response + var ocrText string + for _, textDetection := range response.Response.TextDetections { + ocrText += *textDetection.DetectedText + "\n" + } + + return ocrText, nil +} + +type OCRRequest struct { + ImageBase64 string `json:"image_base64"` + ImageURL string `json:"image_url"` + Scene string `json:"scene"` + APIKey string `json:"apikey" binding:"required"` +} + +type OCRResponse struct { + OriginalText string `json:"original_text"` + Result string `json:"result"` + Success bool `json:"success"` +} + +func (h *OCRService) HandleOCR(c *gin.Context) { + var req OCRRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, OCRResponse{ + Success: false, + Result: "Invalid request format", + }) + return + } + + // Validate API key + if req.APIKey != h.geminiService.APIKey { + c.JSON(http.StatusUnauthorized, OCRResponse{ + Success: false, + Result: "Invalid API key", + }) + return + } + + // Validate that at least one of ImageURL or ImageBase64 is provided + if req.ImageURL == "" && req.ImageBase64 == "" { + c.JSON(http.StatusBadRequest, OCRResponse{ + Success: false, + Result: "Either image_url or image_base64 must be provided", + }) + return + } + + // Process image + ocrText, err := h.ProcessImage(c.Request.Context(), req.ImageBase64) + if err != nil { + c.JSON(http.StatusInternalServerError, OCRResponse{ + Success: false, + Result: "OCR processing failed", + }) + return + } + + // Process with Gemini + processedText, err := h.geminiService.ProcessText(c.Request.Context(), ocrText) + if err != nil { + c.JSON(http.StatusInternalServerError, OCRResponse{ + Success: false, + Result: "Text processing failed", + }) + return + } + + c.JSON(http.StatusOK, OCRResponse{ + Success: true, + OriginalText: ocrText, + Result: processedText, + }) +} \ No newline at end of file diff --git a/.history/pkg/handler/ocr_20250115162124.go b/.history/pkg/handler/ocr_20250115162124.go new file mode 100644 index 0000000..5b7ed34 --- /dev/null +++ b/.history/pkg/handler/ocr_20250115162124.go @@ -0,0 +1,124 @@ +package handler + +import ( + "context" + "net/http" + + "github.com/gin-gonic/gin" + "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common" + "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/profile" + ocr "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/ocr/v20181119" + + "tencent_ocr/pkg/service" +) + +type OCRService struct { + tencentSecretID string + tencentSecretKey string + geminiService *service.GeminiService +} + +func NewOCRService(tencentSecretID, tencentSecretKey string, geminiService *service.GeminiService) *OCRService { + return &OCRService{ + tencentSecretID: tencentSecretID, + tencentSecretKey: tencentSecretKey, + geminiService: geminiService, + } +} + +func (s *OCRService) ProcessImage(ctx context.Context, imageBase64 string) (string, error) { + // Initialize Tencent Cloud client + credential := common.NewCredential(s.tencentSecretID, s.tencentSecretKey) + cpf := profile.NewClientProfile() + cpf.HttpProfile.Endpoint = "ocr.tencentcloudapi.com" + client, err := ocr.NewClient(credential, "", cpf) + if err != nil { + return "", err + } + + // Create OCR request + request := ocr.NewGeneralHandwritingOCRRequest() + request.ImageBase64 = common.StringPtr(imageBase64) + + // Perform OCR + response, err := client.GeneralHandwritingOCR(request) + if err != nil { + return "", err + } + + // Extract text from OCR response + var ocrText string + for _, textDetection := range response.Response.TextDetections { + ocrText += *textDetection.DetectedText + "\n" + } + + return ocrText, nil +} + +type OCRRequest struct { + ImageBase64 string `json:"image_base64"` + ImageURL string `json:"image_url"` + Scene string `json:"scene"` + APIKey string `json:"apikey" binding:"required"` +} + +type OCRResponse struct { + OriginalText string `json:"original_text"` + Result string `json:"result"` + Success bool `json:"success"` +} + +func (h *OCRService) HandleOCR(c *gin.Context) { + var req OCRRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, OCRResponse{ + Success: false, + Result: "Invalid request format", + }) + return + } + + // Validate API key + if req.APIKey != h.geminiService.APIKey { + c.JSON(http.StatusUnauthorized, OCRResponse{ + Success: false, + Result: "Invalid API key", + }) + return + } + + // Validate that at least one of ImageURL or ImageBase64 is provided + if req.ImageURL == "" && req.ImageBase64 == "" { + c.JSON(http.StatusBadRequest, OCRResponse{ + Success: false, + Result: "Either image_url or image_base64 must be provided", + }) + return + } + + // Process image + ocrText, err := h.ProcessImage(c.Request.Context(), req.ImageBase64) + if err != nil { + c.JSON(http.StatusInternalServerError, OCRResponse{ + Success: false, + Result: "OCR processing failed", + }) + return + } + + // Process with Gemini + processedText, err := h.geminiService.ProcessText(c.Request.Context(), ocrText) + if err != nil { + c.JSON(http.StatusInternalServerError, OCRResponse{ + Success: false, + Result: "Text processing failed", + }) + return + } + + c.JSON(http.StatusOK, OCRResponse{ + Success: true, + OriginalText: ocrText, + Result: processedText, + }) +} \ No newline at end of file diff --git a/.history/pkg/handler/ocr_20250115162738.go b/.history/pkg/handler/ocr_20250115162738.go new file mode 100644 index 0000000..8ee44f9 --- /dev/null +++ b/.history/pkg/handler/ocr_20250115162738.go @@ -0,0 +1,86 @@ +package handler + +import ( + "context" + "net/http" + + "github.com/gin-gonic/gin" + "tencent_ocr/pkg/service" + "tencent_ocr/pkg/middleware" + "tencent_ocr/pkg/errors" +) + +type OCRHandler struct { + ocrService *service.OCRService + geminiService *service.GeminiService +} + +func NewOCRHandler(ocrService *service.OCRService, geminiService *service.GeminiService) *OCRHandler { + return &OCRHandler{ + ocrService: ocrService, + geminiService: geminiService, + } +} + +type OCRRequest struct { + ImageBase64 string `json:"image_base64"` + ImageURL string `json:"image_url"` + Scene string `json:"scene"` +} + +type OCRResponse struct { + OriginalText string `json:"original_text"` + Result string `json:"result"` + Success bool `json:"success"` + Error string `json:"error,omitempty"` +} + +func (h *OCRHandler) HandleOCR(c *gin.Context) { + var req OCRRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, OCRResponse{ + Success: false, + Error: "Invalid request format", + }) + return + } + + // Validate that at least one of ImageURL or ImageBase64 is provided + if req.ImageURL == "" && req.ImageBase64 == "" { + c.JSON(http.StatusBadRequest, OCRResponse{ + Success: false, + Error: "Either image_url or image_base64 must be provided", + }) + return + } + + // Process image + ocrText, err := h.ocrService.ProcessImage(c.Request.Context(), req.ImageBase64) + if err != nil { + status := http.StatusInternalServerError + if errors.IsClientError(err) { + status = http.StatusBadRequest + } + c.JSON(status, OCRResponse{ + Success: false, + Error: err.Error(), + }) + return + } + + // Process with Gemini + processedText, err := h.geminiService.ProcessText(c.Request.Context(), ocrText) + if err != nil { + c.JSON(http.StatusInternalServerError, OCRResponse{ + Success: false, + Error: "Text processing failed: " + err.Error(), + }) + return + } + + c.JSON(http.StatusOK, OCRResponse{ + Success: true, + OriginalText: ocrText, + Result: processedText, + }) +} \ No newline at end of file diff --git a/.history/pkg/handler/ocr_20250115162804.go b/.history/pkg/handler/ocr_20250115162804.go new file mode 100644 index 0000000..8ee44f9 --- /dev/null +++ b/.history/pkg/handler/ocr_20250115162804.go @@ -0,0 +1,86 @@ +package handler + +import ( + "context" + "net/http" + + "github.com/gin-gonic/gin" + "tencent_ocr/pkg/service" + "tencent_ocr/pkg/middleware" + "tencent_ocr/pkg/errors" +) + +type OCRHandler struct { + ocrService *service.OCRService + geminiService *service.GeminiService +} + +func NewOCRHandler(ocrService *service.OCRService, geminiService *service.GeminiService) *OCRHandler { + return &OCRHandler{ + ocrService: ocrService, + geminiService: geminiService, + } +} + +type OCRRequest struct { + ImageBase64 string `json:"image_base64"` + ImageURL string `json:"image_url"` + Scene string `json:"scene"` +} + +type OCRResponse struct { + OriginalText string `json:"original_text"` + Result string `json:"result"` + Success bool `json:"success"` + Error string `json:"error,omitempty"` +} + +func (h *OCRHandler) HandleOCR(c *gin.Context) { + var req OCRRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, OCRResponse{ + Success: false, + Error: "Invalid request format", + }) + return + } + + // Validate that at least one of ImageURL or ImageBase64 is provided + if req.ImageURL == "" && req.ImageBase64 == "" { + c.JSON(http.StatusBadRequest, OCRResponse{ + Success: false, + Error: "Either image_url or image_base64 must be provided", + }) + return + } + + // Process image + ocrText, err := h.ocrService.ProcessImage(c.Request.Context(), req.ImageBase64) + if err != nil { + status := http.StatusInternalServerError + if errors.IsClientError(err) { + status = http.StatusBadRequest + } + c.JSON(status, OCRResponse{ + Success: false, + Error: err.Error(), + }) + return + } + + // Process with Gemini + processedText, err := h.geminiService.ProcessText(c.Request.Context(), ocrText) + if err != nil { + c.JSON(http.StatusInternalServerError, OCRResponse{ + Success: false, + Error: "Text processing failed: " + err.Error(), + }) + return + } + + c.JSON(http.StatusOK, OCRResponse{ + Success: true, + OriginalText: ocrText, + Result: processedText, + }) +} \ No newline at end of file diff --git a/.history/pkg/handler/ocr_20250115163423.go b/.history/pkg/handler/ocr_20250115163423.go new file mode 100644 index 0000000..ef114bb --- /dev/null +++ b/.history/pkg/handler/ocr_20250115163423.go @@ -0,0 +1,84 @@ +package handler + +import ( + "net/http" + + "github.com/gin-gonic/gin" + "tencent_ocr/pkg/service" + "tencent_ocr/pkg/errors" +) + +type OCRHandler struct { + ocrService *service.OCRService + geminiService *service.GeminiService +} + +func NewOCRHandler(ocrService *service.OCRService, geminiService *service.GeminiService) *OCRHandler { + return &OCRHandler{ + ocrService: ocrService, + geminiService: geminiService, + } +} + +type OCRRequest struct { + ImageBase64 string `json:"image_base64"` + ImageURL string `json:"image_url"` + Scene string `json:"scene"` +} + +type OCRResponse struct { + OriginalText string `json:"original_text"` + Result string `json:"result"` + Success bool `json:"success"` + Error string `json:"error,omitempty"` +} + +func (h *OCRHandler) HandleOCR(c *gin.Context) { + var req OCRRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, OCRResponse{ + Success: false, + Error: "Invalid request format", + }) + return + } + + // Validate that at least one of ImageURL or ImageBase64 is provided + if req.ImageURL == "" && req.ImageBase64 == "" { + c.JSON(http.StatusBadRequest, OCRResponse{ + Success: false, + Error: "Either image_url or image_base64 must be provided", + }) + return + } + + // Process image + ocrText, err := h.ocrService.ProcessImage(c.Request.Context(), req.ImageBase64) + if err != nil { + status := http.StatusInternalServerError + if errors.IsClientError(err) { + status = http.StatusBadRequest + } + c.JSON(status, OCRResponse{ + Success: false, + Error: err.Error(), + }) + return + } + + // Process with Gemini + processedText, err := h.geminiService.ProcessText(c.Request.Context(), ocrText) + if err != nil { + c.JSON(http.StatusInternalServerError, OCRResponse{ + Success: false, + Error: "Text processing failed: " + err.Error(), + }) + return + } + + c.JSON(http.StatusOK, OCRResponse{ + Success: true, + OriginalText: ocrText, + Result: processedText, + }) +} \ No newline at end of file diff --git a/.history/pkg/handler/ocr_20250115163602.go b/.history/pkg/handler/ocr_20250115163602.go new file mode 100644 index 0000000..ef114bb --- /dev/null +++ b/.history/pkg/handler/ocr_20250115163602.go @@ -0,0 +1,84 @@ +package handler + +import ( + "net/http" + + "github.com/gin-gonic/gin" + "tencent_ocr/pkg/service" + "tencent_ocr/pkg/errors" +) + +type OCRHandler struct { + ocrService *service.OCRService + geminiService *service.GeminiService +} + +func NewOCRHandler(ocrService *service.OCRService, geminiService *service.GeminiService) *OCRHandler { + return &OCRHandler{ + ocrService: ocrService, + geminiService: geminiService, + } +} + +type OCRRequest struct { + ImageBase64 string `json:"image_base64"` + ImageURL string `json:"image_url"` + Scene string `json:"scene"` +} + +type OCRResponse struct { + OriginalText string `json:"original_text"` + Result string `json:"result"` + Success bool `json:"success"` + Error string `json:"error,omitempty"` +} + +func (h *OCRHandler) HandleOCR(c *gin.Context) { + var req OCRRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, OCRResponse{ + Success: false, + Error: "Invalid request format", + }) + return + } + + // Validate that at least one of ImageURL or ImageBase64 is provided + if req.ImageURL == "" && req.ImageBase64 == "" { + c.JSON(http.StatusBadRequest, OCRResponse{ + Success: false, + Error: "Either image_url or image_base64 must be provided", + }) + return + } + + // Process image + ocrText, err := h.ocrService.ProcessImage(c.Request.Context(), req.ImageBase64) + if err != nil { + status := http.StatusInternalServerError + if errors.IsClientError(err) { + status = http.StatusBadRequest + } + c.JSON(status, OCRResponse{ + Success: false, + Error: err.Error(), + }) + return + } + + // Process with Gemini + processedText, err := h.geminiService.ProcessText(c.Request.Context(), ocrText) + if err != nil { + c.JSON(http.StatusInternalServerError, OCRResponse{ + Success: false, + Error: "Text processing failed: " + err.Error(), + }) + return + } + + c.JSON(http.StatusOK, OCRResponse{ + Success: true, + OriginalText: ocrText, + Result: processedText, + }) +} \ No newline at end of file diff --git a/.history/pkg/handler/rate_20250115161526.go b/.history/pkg/handler/rate_20250115161526.go new file mode 100644 index 0000000..28ddbde --- /dev/null +++ b/.history/pkg/handler/rate_20250115161526.go @@ -0,0 +1,177 @@ +package handler + +import ( + "net/http" + "github.com/gin-gonic/gin" + "github.com/google/generative-ai-go/genai" + "google.golang.org/api/option" + "encoding/json" + "strings" + "tencent_ocr/pkg/service" +) + +type RateHandler struct { + geminiAPIKey string + apiKey string +} + +type RateRequest struct { + Content string `json:"content" binding:"required"` + Criteria string `json:"criteria"` + WritingRequirement string `json:"writing_requirement"` + APIKey string `json:"apikey" binding:"required"` +} + +type RateResponse struct { + Rate int `json:"rate"` + Summary string `json:"summary"` + DetailedReview string `json:"detailed_review"` + Success bool `json:"success"` +} + +func NewRateHandler(geminiAPIKey, apiKey string) *RateHandler { + return &RateHandler{ + geminiAPIKey: geminiAPIKey, + apiKey: apiKey, + } +} + +func (h *RateHandler) HandleRate(c *gin.Context) { + var req RateRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, RateResponse{ + Success: false, + }) + return + } + + // Validate API key + if req.APIKey != h.apiKey { + c.JSON(http.StatusUnauthorized, RateResponse{ + Success: false, + }) + return + } + + // Initialize Gemini client + ctx := c.Request.Context() + client, err := genai.NewClient(ctx, option.WithAPIKey(h.geminiAPIKey)) + if err != nil { + c.JSON(http.StatusInternalServerError, RateResponse{ + Success: false, + }) + return + } + defer client.Close() + + // Prepare criteria + criteria := req.Criteria + if criteria == "" { + criteria = `你是一名语文老师。你正在给学生的作文打分。根据以下中考作文评分标准,给作文打分。 +## 评分总分值:100分。 +### 88-100分 符合题意;写作目的和对象明确;思考充分,立意深刻,感情真挚;选材精当,内容充实;中心突出,条理清晰;表达准确,语言流畅。 +### 75-87分 符合题意;写作目的和对象较明确;思考较充分,立意清楚,感情真实;选材合理,内容具体;中心明确,有一定条理;表达较准确,语言通畅。 +### 60-74分 符合题意;写作目的和对象较模糊;有一定思考,感情真实;有一定内容;结构基本完整;语言尚通畅。 +### 60分以下 不符合题意;缺乏写作目的和对象;基本没有思考,感情虚假;内容空洞;结构混乱;不成篇。` + } + writing_requirement := req.WritingRequirement + if writing_requirement == "" { + writing_requirement = "写一篇不少于600字的作文,体裁不限。" + } + + // 规定输出格式是json,包含rate, summary, detailed_review,放入prompt的最后 + format := `请按照以下JSON格式输出: +{ + "rate": 分数, // 最多100分制的分数,int类型 + "summary": "总体评价", // 100字以内的总体评价,string类型 + "detailed_review": "详细点评" // 300字以内的详细点评,包含优点和建议,string类型 +}` + // Prepare prompt + prompt := "作文要求:\n" + writing_requirement + "\n\n" + "评分标准:\n" + criteria + format + "\n\n" + "\n\n作文内容:\n" + req.Content + + // Generate content + model := client.GenerativeModel("gemini-2.0-flash-exp") + resp, err := model.GenerateContent(ctx, genai.Text(prompt)) + if err != nil { + c.JSON(http.StatusInternalServerError, RateResponse{ + Success: false, + }) + return + } + + if len(resp.Candidates) > 0 && len(resp.Candidates[0].Content.Parts) > 0 { + if textPart, ok := resp.Candidates[0].Content.Parts[0].(genai.Text); ok { + // Parse the response to extract rate, summary, and detailed review + result := parseRateResponse(string(textPart)) + + c.JSON(http.StatusOK, RateResponse{ + Rate: result.Rate, + Summary: result.Summary, + DetailedReview: result.Detailed, + Success: true, + }) + return + } + } + + c.JSON(http.StatusInternalServerError, RateResponse{ + Success: false, + }) +} + +type rateResult struct { + Rate int `json:"rate"` + Summary string `json:"summary"` + Detailed string `json:"detailed_review"` +} + +func parseRateResponse(response string) rateResult { + var result rateResult + //去除所有\n + response = strings.ReplaceAll(response, "\n", "") + //去除所有\t + response = strings.ReplaceAll(response, "\t", "") + // 去除response中的```json前缀和```后缀 + response = strings.TrimSpace(response) + response = strings.TrimPrefix(response, "```json") + response = strings.TrimSuffix(response, "```") + + + // 检查response是否是json格式 + if !strings.HasPrefix(response, "{") { + return rateResult{ + Rate: 0, + Summary: "解析失败", + Detailed: "没有左括号", + } + } + if !strings.HasSuffix(response, "}") { + return rateResult{ + Rate: 0, + Summary: "解析失败", + Detailed: "没有右括号", + } + } + + // 解析json + err := json.Unmarshal([]byte(response), &result) + if err != nil { + return rateResult{ + Rate: 0, + Summary: "解析失败", + Detailed: "反序列化失败", + } + } + + // 合并所有验证条件 + if result.Rate <= 0 || result.Rate > 100 || + result.Summary == "" || result.Detailed == "" { + return rateResult{ + Rate: 0, + Summary: "解析失败", + Detailed: "字段验证条件不满足", + } + } + + return result +} \ No newline at end of file diff --git a/.history/pkg/handler/rate_20250115163449.go b/.history/pkg/handler/rate_20250115163449.go new file mode 100644 index 0000000..2ba2b24 --- /dev/null +++ b/.history/pkg/handler/rate_20250115163449.go @@ -0,0 +1,53 @@ +package handler + +import ( + "net/http" + + "github.com/gin-gonic/gin" + "tencent_ocr/pkg/service" +) + +type RateHandler struct { + geminiService *service.GeminiService +} + +func NewRateHandler(geminiService *service.GeminiService) *RateHandler { + return &RateHandler{ + geminiService: geminiService, + } +} + +type RateRequest struct { + Text string `json:"text" binding:"required"` +} + +type RateResponse struct { + Result string `json:"result"` + Success bool `json:"success"` +} + +func (h *RateHandler) HandleRate(c *gin.Context) { + var req RateRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, RateResponse{ + Success: false, + Result: "Invalid request format", + }) + return + } + + // Process with Gemini + result, err := h.geminiService.ProcessText(c.Request.Context(), req.Text) + if err != nil { + c.JSON(http.StatusInternalServerError, RateResponse{ + Success: false, + Result: "Text processing failed", + }) + return + } + + c.JSON(http.StatusOK, RateResponse{ + Success: true, + Result: result, + }) +} \ No newline at end of file diff --git a/.history/pkg/handler/rate_20250115163602.go b/.history/pkg/handler/rate_20250115163602.go new file mode 100644 index 0000000..2ba2b24 --- /dev/null +++ b/.history/pkg/handler/rate_20250115163602.go @@ -0,0 +1,53 @@ +package handler + +import ( + "net/http" + + "github.com/gin-gonic/gin" + "tencent_ocr/pkg/service" +) + +type RateHandler struct { + geminiService *service.GeminiService +} + +func NewRateHandler(geminiService *service.GeminiService) *RateHandler { + return &RateHandler{ + geminiService: geminiService, + } +} + +type RateRequest struct { + Text string `json:"text" binding:"required"` +} + +type RateResponse struct { + Result string `json:"result"` + Success bool `json:"success"` +} + +func (h *RateHandler) HandleRate(c *gin.Context) { + var req RateRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, RateResponse{ + Success: false, + Result: "Invalid request format", + }) + return + } + + // Process with Gemini + result, err := h.geminiService.ProcessText(c.Request.Context(), req.Text) + if err != nil { + c.JSON(http.StatusInternalServerError, RateResponse{ + Success: false, + Result: "Text processing failed", + }) + return + } + + c.JSON(http.StatusOK, RateResponse{ + Success: true, + Result: result, + }) +} \ No newline at end of file diff --git a/.history/pkg/handler/upload_20250115161536.go b/.history/pkg/handler/upload_20250115161536.go new file mode 100644 index 0000000..af8cacb --- /dev/null +++ b/.history/pkg/handler/upload_20250115161536.go @@ -0,0 +1,162 @@ +// 上传文件到cloudflare R2 +package handler +import ( + "bytes" + "fmt" + "net/http" + "github.com/gin-gonic/gin" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/s3" + "encoding/base64" + "io" + "strings" + "tencent_ocr/pkg/service" +) + +type UploadHandler struct { + accessKey string + secretKey string + bucket string + endpoint string + customDomain string + ocrService *OCRService + geminiService *service.GeminiService +} + +type MultiUploadResponse struct { + ImageURLs []string `json:"image_urls"` + Text string `json:"text"` + Success bool `json:"success"` +} + +func (h *UploadHandler) HandleMultiUpload(c *gin.Context) { + form, err := c.MultipartForm() + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to parse form"}) + return + } + + files := form.File["files"] + if len(files) == 0 { + c.JSON(http.StatusBadRequest, gin.H{"error": "No files uploaded"}) + return + } + + if len(files) > 5 { + c.JSON(http.StatusBadRequest, gin.H{"error": "Maximum 5 files allowed"}) + return + } + + var imageURLs []string + var ocrTexts []string + + for _, fileHeader := range files { + if fileHeader.Size > 10<<20 { // 10MB + c.JSON(http.StatusBadRequest, gin.H{"error": "File size exceeds the limit of 10MB"}) + return + } + + file, err := fileHeader.Open() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to open file"}) + return + } + defer file.Close() + + // Read file content + fileBytes, err := io.ReadAll(file) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to read file"}) + return + } + + // Verify file type + contentType := http.DetectContentType(fileBytes) + if !isImage(contentType) { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid file type. Only images are allowed"}) + return + } + + // Convert to base64 + base64Str := base64.StdEncoding.EncodeToString(fileBytes) + + // Process OCR + ocrText, err := h.ocrService.ProcessImage(c.Request.Context(), base64Str) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "OCR processing failed"}) + return + } + ocrTexts = append(ocrTexts, ocrText) + + // Upload to R2 + imageURL, err := h.uploadToR2(fileBytes, fileHeader.Filename, contentType) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to upload file"}) + return + } + imageURLs = append(imageURLs, imageURL) + } + + // Process combined text with Gemini if multiple images + finalText := strings.Join(ocrTexts, "\n") + if len(ocrTexts) > 1 { + prompt := "请将以下多段文字重新组织成一段通顺的文字,保持原意的同时确保语法和逻辑正确:\n\n" + finalText + processedText, err := h.geminiService.ProcessText(c.Request.Context(), prompt) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Text processing failed"}) + return + } + finalText = processedText + } + + c.JSON(http.StatusOK, MultiUploadResponse{ + ImageURLs: imageURLs, + Text: finalText, + Success: true, + }) +} + +// uploadToR2 上传文件到Cloudflare R2 +func (h *UploadHandler) uploadToR2(file []byte, fileName, contentType string) (string, error) { + // 创建S3会话 + sess, err := session.NewSession(&aws.Config{ + Endpoint: aws.String(h.endpoint), + Region: aws.String("auto"), + Credentials: credentials.NewStaticCredentials(h.accessKey, h.secretKey, ""), + }) + if err != nil { + return "", fmt.Errorf("failed to create S3 session: %v", err) + } + + // 创建S3服务客户端 + svc := s3.New(sess) + + // 上传文件到R2 + _, err = svc.PutObject(&s3.PutObjectInput{ + Bucket: aws.String(h.bucket), + Key: aws.String(fileName), + Body: bytes.NewReader(file), + ContentType: aws.String(contentType), + ACL: aws.String("public-read"), // 设置文件为公开可读 + }) + if err != nil { + return "", fmt.Errorf("failed to upload file to R2: %v", err) + } + + // 生成文件的URL + imageURL := fmt.Sprintf("https://%s/%s", h.customDomain, fileName) + return imageURL, nil +} + +// isImage 检查文件是否是图片 +func isImage(contentType string) bool { + allowedTypes := []string{"image/jpeg", "image/png", "image/gif", "image/bmp", "image/tiff", "image/webp"} + for _, t := range allowedTypes { + if contentType == t { + return true + } + } + return false +} \ No newline at end of file diff --git a/.history/pkg/handler/upload_20250115163413.go b/.history/pkg/handler/upload_20250115163413.go new file mode 100644 index 0000000..5b4ab0d --- /dev/null +++ b/.history/pkg/handler/upload_20250115163413.go @@ -0,0 +1,35 @@ +// 上传文件到cloudflare R2 +package handler + +import ( + "net/http" + + "github.com/gin-gonic/gin" + "tencent_ocr/pkg/service" +) + +type UploadHandler struct { + uploadService *service.UploadService + ocrService *service.OCRService + geminiService *service.GeminiService +} + +func NewUploadHandler( + uploadService *service.UploadService, + ocrService *service.OCRService, + geminiService *service.GeminiService, +) *UploadHandler { + return &UploadHandler{ + uploadService: uploadService, + ocrService: ocrService, + geminiService: geminiService, + } +} + +func (h *UploadHandler) HandleUpload(c *gin.Context) { + // Implementation here + c.JSON(http.StatusOK, gin.H{ + "success": true, + "message": "Upload endpoint is working", + }) +} \ No newline at end of file diff --git a/.history/pkg/handler/upload_20250115163532.go b/.history/pkg/handler/upload_20250115163532.go new file mode 100644 index 0000000..17f7496 --- /dev/null +++ b/.history/pkg/handler/upload_20250115163532.go @@ -0,0 +1,154 @@ +// 上传文件到cloudflare R2 +package handler + +import ( + "encoding/base64" + "io" + "net/http" + "strings" + + "github.com/gin-gonic/gin" + "tencent_ocr/pkg/service" +) + +type UploadHandler struct { + uploadService *service.UploadService + ocrService *service.OCRService + geminiService *service.GeminiService +} + +func NewUploadHandler( + uploadService *service.UploadService, + ocrService *service.OCRService, + geminiService *service.GeminiService, +) *UploadHandler { + return &UploadHandler{ + uploadService: uploadService, + ocrService: ocrService, + geminiService: geminiService, + } +} + +type MultiUploadResponse struct { + ImageURLs []string `json:"image_urls"` + Text string `json:"text"` + Success bool `json:"success"` + Error string `json:"error,omitempty"` +} + +func (h *UploadHandler) HandleUpload(c *gin.Context) { + form, err := c.MultipartForm() + if err != nil { + c.JSON(http.StatusBadRequest, MultiUploadResponse{ + Success: false, + Error: "Failed to parse form", + }) + return + } + + files := form.File["files"] + if len(files) == 0 { + c.JSON(http.StatusBadRequest, MultiUploadResponse{ + Success: false, + Error: "No files uploaded", + }) + return + } + + if len(files) > 5 { + c.JSON(http.StatusBadRequest, MultiUploadResponse{ + Success: false, + Error: "Maximum 5 files allowed", + }) + return + } + + var imageURLs []string + var ocrTexts []string + + for _, fileHeader := range files { + if fileHeader.Size > 10<<20 { // 10MB + c.JSON(http.StatusBadRequest, MultiUploadResponse{ + Success: false, + Error: "File size exceeds the limit of 10MB", + }) + return + } + + file, err := fileHeader.Open() + if err != nil { + c.JSON(http.StatusInternalServerError, MultiUploadResponse{ + Success: false, + Error: "Failed to open file", + }) + return + } + defer file.Close() + + // Read file content for content type detection + fileBytes, err := io.ReadAll(file) + if err != nil { + c.JSON(http.StatusInternalServerError, MultiUploadResponse{ + Success: false, + Error: "Failed to read file", + }) + return + } + + // Verify file type + contentType := http.DetectContentType(fileBytes) + if !h.uploadService.IsValidFileType(contentType) { + c.JSON(http.StatusBadRequest, MultiUploadResponse{ + Success: false, + Error: "Invalid file type. Only images are allowed", + }) + return + } + + // Convert to base64 for OCR + base64Str := base64.StdEncoding.EncodeToString(fileBytes) + + // Process OCR + ocrText, err := h.ocrService.ProcessImage(c.Request.Context(), base64Str) + if err != nil { + c.JSON(http.StatusInternalServerError, MultiUploadResponse{ + Success: false, + Error: "OCR processing failed", + }) + return + } + ocrTexts = append(ocrTexts, ocrText) + + // Upload to R2 + imageURL, err := h.uploadService.UploadFile(file, fileHeader.Filename, contentType) + if err != nil { + c.JSON(http.StatusInternalServerError, MultiUploadResponse{ + Success: false, + Error: "Failed to upload file", + }) + return + } + imageURLs = append(imageURLs, imageURL) + } + + // Process combined text with Gemini if multiple images + finalText := strings.Join(ocrTexts, "\n") + if len(ocrTexts) > 1 { + prompt := "请将以下多段文字重新组织成一段通顺的文字,保持原意的同时确保语法和逻辑正确:\n\n" + finalText + processedText, err := h.geminiService.ProcessText(c.Request.Context(), prompt) + if err != nil { + c.JSON(http.StatusInternalServerError, MultiUploadResponse{ + Success: false, + Error: "Text processing failed", + }) + return + } + finalText = processedText + } + + c.JSON(http.StatusOK, MultiUploadResponse{ + ImageURLs: imageURLs, + Text: finalText, + Success: true, + }) +} \ No newline at end of file diff --git a/.history/pkg/handler/upload_20250115163602.go b/.history/pkg/handler/upload_20250115163602.go new file mode 100644 index 0000000..17f7496 --- /dev/null +++ b/.history/pkg/handler/upload_20250115163602.go @@ -0,0 +1,154 @@ +// 上传文件到cloudflare R2 +package handler + +import ( + "encoding/base64" + "io" + "net/http" + "strings" + + "github.com/gin-gonic/gin" + "tencent_ocr/pkg/service" +) + +type UploadHandler struct { + uploadService *service.UploadService + ocrService *service.OCRService + geminiService *service.GeminiService +} + +func NewUploadHandler( + uploadService *service.UploadService, + ocrService *service.OCRService, + geminiService *service.GeminiService, +) *UploadHandler { + return &UploadHandler{ + uploadService: uploadService, + ocrService: ocrService, + geminiService: geminiService, + } +} + +type MultiUploadResponse struct { + ImageURLs []string `json:"image_urls"` + Text string `json:"text"` + Success bool `json:"success"` + Error string `json:"error,omitempty"` +} + +func (h *UploadHandler) HandleUpload(c *gin.Context) { + form, err := c.MultipartForm() + if err != nil { + c.JSON(http.StatusBadRequest, MultiUploadResponse{ + Success: false, + Error: "Failed to parse form", + }) + return + } + + files := form.File["files"] + if len(files) == 0 { + c.JSON(http.StatusBadRequest, MultiUploadResponse{ + Success: false, + Error: "No files uploaded", + }) + return + } + + if len(files) > 5 { + c.JSON(http.StatusBadRequest, MultiUploadResponse{ + Success: false, + Error: "Maximum 5 files allowed", + }) + return + } + + var imageURLs []string + var ocrTexts []string + + for _, fileHeader := range files { + if fileHeader.Size > 10<<20 { // 10MB + c.JSON(http.StatusBadRequest, MultiUploadResponse{ + Success: false, + Error: "File size exceeds the limit of 10MB", + }) + return + } + + file, err := fileHeader.Open() + if err != nil { + c.JSON(http.StatusInternalServerError, MultiUploadResponse{ + Success: false, + Error: "Failed to open file", + }) + return + } + defer file.Close() + + // Read file content for content type detection + fileBytes, err := io.ReadAll(file) + if err != nil { + c.JSON(http.StatusInternalServerError, MultiUploadResponse{ + Success: false, + Error: "Failed to read file", + }) + return + } + + // Verify file type + contentType := http.DetectContentType(fileBytes) + if !h.uploadService.IsValidFileType(contentType) { + c.JSON(http.StatusBadRequest, MultiUploadResponse{ + Success: false, + Error: "Invalid file type. Only images are allowed", + }) + return + } + + // Convert to base64 for OCR + base64Str := base64.StdEncoding.EncodeToString(fileBytes) + + // Process OCR + ocrText, err := h.ocrService.ProcessImage(c.Request.Context(), base64Str) + if err != nil { + c.JSON(http.StatusInternalServerError, MultiUploadResponse{ + Success: false, + Error: "OCR processing failed", + }) + return + } + ocrTexts = append(ocrTexts, ocrText) + + // Upload to R2 + imageURL, err := h.uploadService.UploadFile(file, fileHeader.Filename, contentType) + if err != nil { + c.JSON(http.StatusInternalServerError, MultiUploadResponse{ + Success: false, + Error: "Failed to upload file", + }) + return + } + imageURLs = append(imageURLs, imageURL) + } + + // Process combined text with Gemini if multiple images + finalText := strings.Join(ocrTexts, "\n") + if len(ocrTexts) > 1 { + prompt := "请将以下多段文字重新组织成一段通顺的文字,保持原意的同时确保语法和逻辑正确:\n\n" + finalText + processedText, err := h.geminiService.ProcessText(c.Request.Context(), prompt) + if err != nil { + c.JSON(http.StatusInternalServerError, MultiUploadResponse{ + Success: false, + Error: "Text processing failed", + }) + return + } + finalText = processedText + } + + c.JSON(http.StatusOK, MultiUploadResponse{ + ImageURLs: imageURLs, + Text: finalText, + Success: true, + }) +} \ No newline at end of file diff --git a/.history/pkg/middleware/auth_20250115162815.go b/.history/pkg/middleware/auth_20250115162815.go new file mode 100644 index 0000000..0519ecb --- /dev/null +++ b/.history/pkg/middleware/auth_20250115162815.go @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/.history/pkg/middleware/auth_20250115162818.go b/.history/pkg/middleware/auth_20250115162818.go new file mode 100644 index 0000000..0e61b4f --- /dev/null +++ b/.history/pkg/middleware/auth_20250115162818.go @@ -0,0 +1,37 @@ +package middleware + +import ( + "net/http" + + "github.com/gin-gonic/gin" +) + +const ( + APIKeyHeader = "X-API-Key" +) + +// APIKeyAuth creates a middleware that validates the API key +func APIKeyAuth(validAPIKey string) gin.HandlerFunc { + return func(c *gin.Context) { + apiKey := c.GetHeader(APIKeyHeader) + if apiKey == "" { + c.JSON(http.StatusUnauthorized, gin.H{ + "success": false, + "error": "API key is required", + }) + c.Abort() + return + } + + if apiKey != validAPIKey { + c.JSON(http.StatusUnauthorized, gin.H{ + "success": false, + "error": "Invalid API key", + }) + c.Abort() + return + } + + c.Next() + } +} \ No newline at end of file diff --git a/.history/pkg/middleware/auth_20250115162833.go b/.history/pkg/middleware/auth_20250115162833.go new file mode 100644 index 0000000..0e61b4f --- /dev/null +++ b/.history/pkg/middleware/auth_20250115162833.go @@ -0,0 +1,37 @@ +package middleware + +import ( + "net/http" + + "github.com/gin-gonic/gin" +) + +const ( + APIKeyHeader = "X-API-Key" +) + +// APIKeyAuth creates a middleware that validates the API key +func APIKeyAuth(validAPIKey string) gin.HandlerFunc { + return func(c *gin.Context) { + apiKey := c.GetHeader(APIKeyHeader) + if apiKey == "" { + c.JSON(http.StatusUnauthorized, gin.H{ + "success": false, + "error": "API key is required", + }) + c.Abort() + return + } + + if apiKey != validAPIKey { + c.JSON(http.StatusUnauthorized, gin.H{ + "success": false, + "error": "Invalid API key", + }) + c.Abort() + return + } + + c.Next() + } +} \ No newline at end of file diff --git a/.history/pkg/middleware/ratelimit_20250115162829.go b/.history/pkg/middleware/ratelimit_20250115162829.go new file mode 100644 index 0000000..0519ecb --- /dev/null +++ b/.history/pkg/middleware/ratelimit_20250115162829.go @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/.history/pkg/middleware/ratelimit_20250115162834.go b/.history/pkg/middleware/ratelimit_20250115162834.go new file mode 100644 index 0000000..64e6fb4 --- /dev/null +++ b/.history/pkg/middleware/ratelimit_20250115162834.go @@ -0,0 +1,76 @@ +package middleware + +import ( + "net/http" + "sync" + "time" + + "github.com/gin-gonic/gin" + "golang.org/x/time/rate" +) + +type RateLimiter struct { + limiters map[string]*rate.Limiter + mu sync.RWMutex + r rate.Limit + b int +} + +func NewRateLimiter(r rate.Limit, b int) *RateLimiter { + return &RateLimiter{ + limiters: make(map[string]*rate.Limiter), + r: r, + b: b, + } +} + +func (rl *RateLimiter) getLimiter(key string) *rate.Limiter { + rl.mu.Lock() + defer rl.mu.Unlock() + + limiter, exists := rl.limiters[key] + if !exists { + limiter = rate.NewLimiter(rl.r, rl.b) + rl.limiters[key] = limiter + } + + return limiter +} + +// RateLimit creates a middleware that limits request rates per API key +func RateLimit(r rate.Limit, burst int) gin.HandlerFunc { + rl := NewRateLimiter(r, burst) + + return func(c *gin.Context) { + key := c.GetHeader(APIKeyHeader) + if key == "" { + key = c.ClientIP() + } + + limiter := rl.getLimiter(key) + if !limiter.Allow() { + c.JSON(http.StatusTooManyRequests, gin.H{ + "success": false, + "error": "Rate limit exceeded", + }) + c.Abort() + return + } + + c.Next() + } +} + +// CleanupTask periodically removes unused limiters +func (rl *RateLimiter) CleanupTask(interval time.Duration) { + ticker := time.NewTicker(interval) + go func() { + for range ticker.C { + rl.mu.Lock() + for key := range rl.limiters { + delete(rl.limiters, key) + } + rl.mu.Unlock() + } + }() +} \ No newline at end of file diff --git a/.history/pkg/middleware/ratelimit_20250115162845.go b/.history/pkg/middleware/ratelimit_20250115162845.go new file mode 100644 index 0000000..64e6fb4 --- /dev/null +++ b/.history/pkg/middleware/ratelimit_20250115162845.go @@ -0,0 +1,76 @@ +package middleware + +import ( + "net/http" + "sync" + "time" + + "github.com/gin-gonic/gin" + "golang.org/x/time/rate" +) + +type RateLimiter struct { + limiters map[string]*rate.Limiter + mu sync.RWMutex + r rate.Limit + b int +} + +func NewRateLimiter(r rate.Limit, b int) *RateLimiter { + return &RateLimiter{ + limiters: make(map[string]*rate.Limiter), + r: r, + b: b, + } +} + +func (rl *RateLimiter) getLimiter(key string) *rate.Limiter { + rl.mu.Lock() + defer rl.mu.Unlock() + + limiter, exists := rl.limiters[key] + if !exists { + limiter = rate.NewLimiter(rl.r, rl.b) + rl.limiters[key] = limiter + } + + return limiter +} + +// RateLimit creates a middleware that limits request rates per API key +func RateLimit(r rate.Limit, burst int) gin.HandlerFunc { + rl := NewRateLimiter(r, burst) + + return func(c *gin.Context) { + key := c.GetHeader(APIKeyHeader) + if key == "" { + key = c.ClientIP() + } + + limiter := rl.getLimiter(key) + if !limiter.Allow() { + c.JSON(http.StatusTooManyRequests, gin.H{ + "success": false, + "error": "Rate limit exceeded", + }) + c.Abort() + return + } + + c.Next() + } +} + +// CleanupTask periodically removes unused limiters +func (rl *RateLimiter) CleanupTask(interval time.Duration) { + ticker := time.NewTicker(interval) + go func() { + for range ticker.C { + rl.mu.Lock() + for key := range rl.limiters { + delete(rl.limiters, key) + } + rl.mu.Unlock() + } + }() +} \ No newline at end of file diff --git a/.history/pkg/service/gemini_20250115163547.go b/.history/pkg/service/gemini_20250115163547.go new file mode 100644 index 0000000..1b775da --- /dev/null +++ b/.history/pkg/service/gemini_20250115163547.go @@ -0,0 +1,62 @@ +package service + +import ( + "context" + "sync" + + "github.com/google/generative-ai-go/genai" + "google.golang.org/api/option" + "tencent_ocr/pkg/errors" +) + +type GeminiService struct { + client *genai.Client + model *genai.GenerativeModel + apiKey string + mu sync.RWMutex +} + +func NewGeminiService(apiKey string) (*GeminiService, error) { + ctx := context.Background() + client, err := genai.NewClient(ctx, option.WithAPIKey(apiKey)) + if err != nil { + return nil, errors.Wrap(err, "failed to create Gemini client") + } + + return &GeminiService{ + client: client, + model: client.GenerativeModel("gemini-pro"), + apiKey: apiKey, + }, nil +} + +func (s *GeminiService) ProcessText(ctx context.Context, text string) (string, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + resp, err := s.model.GenerateContent(ctx, genai.Text(text)) + if err != nil { + return "", errors.Wrap(err, "failed to generate content") + } + + if len(resp.Candidates) == 0 || len(resp.Candidates[0].Content.Parts) == 0 { + return "", errors.New("no response from Gemini") + } + + if textPart, ok := resp.Candidates[0].Content.Parts[0].(genai.Text); ok { + return string(textPart), nil + } + + return "", errors.New("invalid response format from Gemini") +} + +// Close implements graceful shutdown +func (s *GeminiService) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.client != nil { + s.client.Close() + } + return nil +} \ No newline at end of file diff --git a/.history/pkg/service/gemini_20250115163602.go b/.history/pkg/service/gemini_20250115163602.go new file mode 100644 index 0000000..1b775da --- /dev/null +++ b/.history/pkg/service/gemini_20250115163602.go @@ -0,0 +1,62 @@ +package service + +import ( + "context" + "sync" + + "github.com/google/generative-ai-go/genai" + "google.golang.org/api/option" + "tencent_ocr/pkg/errors" +) + +type GeminiService struct { + client *genai.Client + model *genai.GenerativeModel + apiKey string + mu sync.RWMutex +} + +func NewGeminiService(apiKey string) (*GeminiService, error) { + ctx := context.Background() + client, err := genai.NewClient(ctx, option.WithAPIKey(apiKey)) + if err != nil { + return nil, errors.Wrap(err, "failed to create Gemini client") + } + + return &GeminiService{ + client: client, + model: client.GenerativeModel("gemini-pro"), + apiKey: apiKey, + }, nil +} + +func (s *GeminiService) ProcessText(ctx context.Context, text string) (string, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + resp, err := s.model.GenerateContent(ctx, genai.Text(text)) + if err != nil { + return "", errors.Wrap(err, "failed to generate content") + } + + if len(resp.Candidates) == 0 || len(resp.Candidates[0].Content.Parts) == 0 { + return "", errors.New("no response from Gemini") + } + + if textPart, ok := resp.Candidates[0].Content.Parts[0].(genai.Text); ok { + return string(textPart), nil + } + + return "", errors.New("invalid response format from Gemini") +} + +// Close implements graceful shutdown +func (s *GeminiService) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + + if s.client != nil { + s.client.Close() + } + return nil +} \ No newline at end of file diff --git a/.history/pkg/service/ocr_20250115162751.go b/.history/pkg/service/ocr_20250115162751.go new file mode 100644 index 0000000..0519ecb --- /dev/null +++ b/.history/pkg/service/ocr_20250115162751.go @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/.history/pkg/service/ocr_20250115162754.go b/.history/pkg/service/ocr_20250115162754.go new file mode 100644 index 0000000..489b68a --- /dev/null +++ b/.history/pkg/service/ocr_20250115162754.go @@ -0,0 +1,71 @@ +package service + +import ( + "context" + "sync" + + "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common" + "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/profile" + ocr "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/ocr/v20181119" + "tencent_ocr/pkg/errors" +) + +type OCRService struct { + client *ocr.Client + tencentSecretID string + tencentSecretKey string + mu sync.RWMutex +} + +func NewOCRService(tencentSecretID, tencentSecretKey string) (*OCRService, error) { + credential := common.NewCredential(tencentSecretID, tencentSecretKey) + cpf := profile.NewClientProfile() + cpf.HttpProfile.Endpoint = "ocr.tencentcloudapi.com" + + client, err := ocr.NewClient(credential, "", cpf) + if err != nil { + return nil, errors.Wrap(err, "failed to create Tencent Cloud OCR client") + } + + return &OCRService{ + client: client, + tencentSecretID: tencentSecretID, + tencentSecretKey: tencentSecretKey, + }, nil +} + +func (s *OCRService) ProcessImage(ctx context.Context, imageBase64 string) (string, error) { + if imageBase64 == "" { + return "", errors.NewClientError("image data is required") + } + + s.mu.RLock() + defer s.mu.RUnlock() + + // Create OCR request + request := ocr.NewGeneralHandwritingOCRRequest() + request.ImageBase64 = common.StringPtr(imageBase64) + + // Perform OCR + response, err := s.client.GeneralHandwritingOCRWithContext(ctx, request) + if err != nil { + return "", errors.Wrap(err, "failed to perform OCR") + } + + // Extract text from OCR response + var ocrText string + for _, textDetection := range response.Response.TextDetections { + ocrText += *textDetection.DetectedText + "\n" + } + + return ocrText, nil +} + +// Close implements graceful shutdown +func (s *OCRService) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + + // Add any cleanup logic here if needed + return nil +} \ No newline at end of file diff --git a/.history/pkg/service/ocr_20250115162813.go b/.history/pkg/service/ocr_20250115162813.go new file mode 100644 index 0000000..489b68a --- /dev/null +++ b/.history/pkg/service/ocr_20250115162813.go @@ -0,0 +1,71 @@ +package service + +import ( + "context" + "sync" + + "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common" + "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/profile" + ocr "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/ocr/v20181119" + "tencent_ocr/pkg/errors" +) + +type OCRService struct { + client *ocr.Client + tencentSecretID string + tencentSecretKey string + mu sync.RWMutex +} + +func NewOCRService(tencentSecretID, tencentSecretKey string) (*OCRService, error) { + credential := common.NewCredential(tencentSecretID, tencentSecretKey) + cpf := profile.NewClientProfile() + cpf.HttpProfile.Endpoint = "ocr.tencentcloudapi.com" + + client, err := ocr.NewClient(credential, "", cpf) + if err != nil { + return nil, errors.Wrap(err, "failed to create Tencent Cloud OCR client") + } + + return &OCRService{ + client: client, + tencentSecretID: tencentSecretID, + tencentSecretKey: tencentSecretKey, + }, nil +} + +func (s *OCRService) ProcessImage(ctx context.Context, imageBase64 string) (string, error) { + if imageBase64 == "" { + return "", errors.NewClientError("image data is required") + } + + s.mu.RLock() + defer s.mu.RUnlock() + + // Create OCR request + request := ocr.NewGeneralHandwritingOCRRequest() + request.ImageBase64 = common.StringPtr(imageBase64) + + // Perform OCR + response, err := s.client.GeneralHandwritingOCRWithContext(ctx, request) + if err != nil { + return "", errors.Wrap(err, "failed to perform OCR") + } + + // Extract text from OCR response + var ocrText string + for _, textDetection := range response.Response.TextDetections { + ocrText += *textDetection.DetectedText + "\n" + } + + return ocrText, nil +} + +// Close implements graceful shutdown +func (s *OCRService) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + + // Add any cleanup logic here if needed + return nil +} \ No newline at end of file diff --git a/.history/pkg/service/upload_20250115163504.go b/.history/pkg/service/upload_20250115163504.go new file mode 100644 index 0000000..0519ecb --- /dev/null +++ b/.history/pkg/service/upload_20250115163504.go @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/.history/pkg/service/upload_20250115163508.go b/.history/pkg/service/upload_20250115163508.go new file mode 100644 index 0000000..b62cc0d --- /dev/null +++ b/.history/pkg/service/upload_20250115163508.go @@ -0,0 +1,92 @@ +package service + +import ( + "bytes" + "fmt" + "io" + "net/http" + "sync" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/s3" + "tencent_ocr/pkg/errors" +) + +type UploadService struct { + s3Client *s3.S3 + bucket string + customDomain string + mu sync.RWMutex +} + +func NewUploadService(accessKey, secretKey, bucket, endpoint, customDomain string) (*UploadService, error) { + sess, err := session.NewSession(&aws.Config{ + Endpoint: aws.String(endpoint), + Region: aws.String("auto"), + Credentials: credentials.NewStaticCredentials(accessKey, secretKey, ""), + }) + if err != nil { + return nil, errors.Wrap(err, "failed to create S3 session") + } + + return &UploadService{ + s3Client: s3.New(sess), + bucket: bucket, + customDomain: customDomain, + }, nil +} + +func (s *UploadService) UploadFile(file io.Reader, fileName, contentType string) (string, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + // Read file content + fileBytes, err := io.ReadAll(file) + if err != nil { + return "", errors.Wrap(err, "failed to read file") + } + + // Upload file to R2 + _, err = s.s3Client.PutObject(&s3.PutObjectInput{ + Bucket: aws.String(s.bucket), + Key: aws.String(fileName), + Body: bytes.NewReader(fileBytes), + ContentType: aws.String(contentType), + ACL: aws.String("public-read"), + }) + if err != nil { + return "", errors.Wrap(err, "failed to upload file to R2") + } + + // Generate file URL + imageURL := fmt.Sprintf("https://%s/%s", s.customDomain, fileName) + return imageURL, nil +} + +func (s *UploadService) IsValidFileType(contentType string) bool { + allowedTypes := []string{ + "image/jpeg", + "image/png", + "image/gif", + "image/bmp", + "image/tiff", + "image/webp", + } + for _, t := range allowedTypes { + if contentType == t { + return true + } + } + return false +} + +// Close implements graceful shutdown +func (s *UploadService) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + + // Add any cleanup logic here if needed + return nil +} \ No newline at end of file diff --git a/.history/pkg/service/upload_20250115163602.go b/.history/pkg/service/upload_20250115163602.go new file mode 100644 index 0000000..b62cc0d --- /dev/null +++ b/.history/pkg/service/upload_20250115163602.go @@ -0,0 +1,92 @@ +package service + +import ( + "bytes" + "fmt" + "io" + "net/http" + "sync" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/s3" + "tencent_ocr/pkg/errors" +) + +type UploadService struct { + s3Client *s3.S3 + bucket string + customDomain string + mu sync.RWMutex +} + +func NewUploadService(accessKey, secretKey, bucket, endpoint, customDomain string) (*UploadService, error) { + sess, err := session.NewSession(&aws.Config{ + Endpoint: aws.String(endpoint), + Region: aws.String("auto"), + Credentials: credentials.NewStaticCredentials(accessKey, secretKey, ""), + }) + if err != nil { + return nil, errors.Wrap(err, "failed to create S3 session") + } + + return &UploadService{ + s3Client: s3.New(sess), + bucket: bucket, + customDomain: customDomain, + }, nil +} + +func (s *UploadService) UploadFile(file io.Reader, fileName, contentType string) (string, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + // Read file content + fileBytes, err := io.ReadAll(file) + if err != nil { + return "", errors.Wrap(err, "failed to read file") + } + + // Upload file to R2 + _, err = s.s3Client.PutObject(&s3.PutObjectInput{ + Bucket: aws.String(s.bucket), + Key: aws.String(fileName), + Body: bytes.NewReader(fileBytes), + ContentType: aws.String(contentType), + ACL: aws.String("public-read"), + }) + if err != nil { + return "", errors.Wrap(err, "failed to upload file to R2") + } + + // Generate file URL + imageURL := fmt.Sprintf("https://%s/%s", s.customDomain, fileName) + return imageURL, nil +} + +func (s *UploadService) IsValidFileType(contentType string) bool { + allowedTypes := []string{ + "image/jpeg", + "image/png", + "image/gif", + "image/bmp", + "image/tiff", + "image/webp", + } + for _, t := range allowedTypes { + if contentType == t { + return true + } + } + return false +} + +// Close implements graceful shutdown +func (s *UploadService) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + + // Add any cleanup logic here if needed + return nil +} \ No newline at end of file diff --git a/.history/pkg/service/upload_20250115163757.go b/.history/pkg/service/upload_20250115163757.go new file mode 100644 index 0000000..8e5bd8b --- /dev/null +++ b/.history/pkg/service/upload_20250115163757.go @@ -0,0 +1,92 @@ +package service + +import ( + "bytes" + "fmt" + "io" + + "sync" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/s3" + "tencent_ocr/pkg/errors" +) + +type UploadService struct { + s3Client *s3.S3 + bucket string + customDomain string + mu sync.RWMutex +} + +func NewUploadService(accessKey, secretKey, bucket, endpoint, customDomain string) (*UploadService, error) { + sess, err := session.NewSession(&aws.Config{ + Endpoint: aws.String(endpoint), + Region: aws.String("auto"), + Credentials: credentials.NewStaticCredentials(accessKey, secretKey, ""), + }) + if err != nil { + return nil, errors.Wrap(err, "failed to create S3 session") + } + + return &UploadService{ + s3Client: s3.New(sess), + bucket: bucket, + customDomain: customDomain, + }, nil +} + +func (s *UploadService) UploadFile(file io.Reader, fileName, contentType string) (string, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + // Read file content + fileBytes, err := io.ReadAll(file) + if err != nil { + return "", errors.Wrap(err, "failed to read file") + } + + // Upload file to R2 + _, err = s.s3Client.PutObject(&s3.PutObjectInput{ + Bucket: aws.String(s.bucket), + Key: aws.String(fileName), + Body: bytes.NewReader(fileBytes), + ContentType: aws.String(contentType), + ACL: aws.String("public-read"), + }) + if err != nil { + return "", errors.Wrap(err, "failed to upload file to R2") + } + + // Generate file URL + imageURL := fmt.Sprintf("https://%s/%s", s.customDomain, fileName) + return imageURL, nil +} + +func (s *UploadService) IsValidFileType(contentType string) bool { + allowedTypes := []string{ + "image/jpeg", + "image/png", + "image/gif", + "image/bmp", + "image/tiff", + "image/webp", + } + for _, t := range allowedTypes { + if contentType == t { + return true + } + } + return false +} + +// Close implements graceful shutdown +func (s *UploadService) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + + // Add any cleanup logic here if needed + return nil +} \ No newline at end of file diff --git a/cmd/main.go b/cmd/main.go index 5460f36..a0f50fc 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -1,11 +1,21 @@ package main import ( + "context" "log" + "net/http" + "os" + "os/signal" + "syscall" + "time" "github.com/gin-gonic/gin" - "tencenthw/pkg/config" - "tencenthw/pkg/handler" + "golang.org/x/time/rate" + + "tencent_ocr/pkg/config" + "tencent_ocr/pkg/handler" + "tencent_ocr/pkg/middleware" + "tencent_ocr/pkg/service" ) func main() { @@ -14,53 +24,72 @@ func main() { if err != nil { log.Fatalf("Failed to load configuration: %v", err) } + // Initialize services + ocrService, err := service.NewOCRService(cfg.TencentSecretID, cfg.TencentSecretKey) + if err != nil { + log.Fatalf("Failed to initialize OCR service: %v", err) + } + defer ocrService.Close() + geminiService, err := service.NewGeminiService(cfg.GeminiAPIKey) if err != nil { - log.Fatal(err) + log.Fatalf("Failed to initialize Gemini service: %v", err) } defer geminiService.Close() - ocrService := handler.NewOCRService( - cfg.TencentSecretID, - cfg.TencentSecretKey, - geminiService, - ) - - // Initialize handlers - ocrHandler := handler.NewOCRHandler( - cfg.TencentSecretID, - cfg.TencentSecretKey, - cfg.GeminiAPIKey, - cfg.APIKey, - ) - - rateHandler := handler.NewRateHandler( - cfg.GeminiAPIKey, - cfg.APIKey, - ) - - uploadHandler := handler.NewUploadHandler( + uploadService, err := service.NewUploadService( cfg.R2AccessKey, cfg.R2SecretKey, cfg.R2Bucket, cfg.R2Endpoint, cfg.R2CustomDomain, - ocrService, - geminiService, ) + if err != nil { + log.Fatalf("Failed to initialize upload service: %v", err) + } + defer uploadService.Close() + + // Initialize handlers + ocrHandler := handler.NewOCRHandler(ocrService, geminiService) + rateHandler := handler.NewRateHandler(geminiService) + uploadHandler := handler.NewUploadHandler(uploadService, ocrService, geminiService) // Setup Gin router r := gin.Default() + // Add middleware + r.Use(middleware.APIKeyAuth(cfg.APIKey)) + r.Use(middleware.RateLimit(rate.Limit(10), 20)) // 10 requests per second with burst of 20 + // Register routes r.POST("/ocr", ocrHandler.HandleOCR) r.POST("/rate", rateHandler.HandleRate) - // upload file to server r.POST("/upload", uploadHandler.HandleUpload) - // Start server - if err := r.Run("localhost:8080"); err != nil { + // Create server with graceful shutdown + srv := &http.Server{ + Addr: "localhost:8080", + Handler: r, + } + + // Graceful shutdown + go func() { + quit := make(chan os.Signal, 1) + signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) + <-quit + log.Println("Shutting down server...") + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if err := srv.Shutdown(ctx); err != nil { + log.Fatal("Server forced to shutdown:", err) + } + }() + + log.Println("Server starting on localhost:8080") + if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { log.Fatalf("Failed to start server: %v", err) } } \ No newline at end of file diff --git a/go.mod b/go.mod index 5f6188b..4eed5ea 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module tencenthw +module tencent_ocr go 1.23.4 @@ -9,6 +9,7 @@ require ( github.com/joho/godotenv v1.5.1 github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common v1.0.1081 github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/ocr v1.0.1081 + golang.org/x/time v0.9.0 google.golang.org/api v0.216.0 ) @@ -58,7 +59,6 @@ require ( golang.org/x/sync v0.10.0 // indirect golang.org/x/sys v0.28.0 // indirect golang.org/x/text v0.21.0 // indirect - golang.org/x/time v0.9.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20241209162323-e6fa225c2576 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250102185135-69823020774d // indirect google.golang.org/grpc v1.69.2 // indirect diff --git a/pkg/errors/errors.go b/pkg/errors/errors.go new file mode 100644 index 0000000..2190bd5 --- /dev/null +++ b/pkg/errors/errors.go @@ -0,0 +1,82 @@ +package errors + +import ( + "fmt" +) + +type Error struct { + Message string + Cause error + Type ErrorType +} + +type ErrorType int + +const ( + ErrorTypeUnknown ErrorType = iota + ErrorTypeClient + ErrorTypeServer +) + +func (e *Error) Error() string { + if e.Cause != nil { + return fmt.Sprintf("%s: %v", e.Message, e.Cause) + } + return e.Message +} + +func (e *Error) Unwrap() error { + return e.Cause +} + +func New(message string) error { + return &Error{ + Message: message, + Type: ErrorTypeUnknown, + } +} + +func Wrap(err error, message string) error { + if err == nil { + return nil + } + return &Error{ + Message: message, + Cause: err, + Type: ErrorTypeUnknown, + } +} + +func NewClientError(message string) error { + return &Error{ + Message: message, + Type: ErrorTypeClient, + } +} + +func NewServerError(message string) error { + return &Error{ + Message: message, + Type: ErrorTypeServer, + } +} + +func IsClientError(err error) bool { + if err == nil { + return false + } + if e, ok := err.(*Error); ok { + return e.Type == ErrorTypeClient + } + return false +} + +func IsServerError(err error) bool { + if err == nil { + return false + } + if e, ok := err.(*Error); ok { + return e.Type == ErrorTypeServer + } + return false +} \ No newline at end of file diff --git a/pkg/handler/ocr.go b/pkg/handler/ocr.go index 6ab82ac..ef114bb 100644 --- a/pkg/handler/ocr.go +++ b/pkg/handler/ocr.go @@ -1,91 +1,44 @@ package handler import ( - "context" - "encoding/base64" "net/http" - "strings" "github.com/gin-gonic/gin" - "github.com/google/generative-ai-go/genai" - "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common" - "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/profile" - ocr "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/ocr/v20181119" - "google.golang.org/api/option" - "git.disbaidu.com/maxwell/tencent_ocr/pkg/service" + "tencent_ocr/pkg/service" + "tencent_ocr/pkg/errors" ) -type OCRService struct { - tencentSecretID string - tencentSecretKey string - geminiService *service.GeminiService +type OCRHandler struct { + ocrService *service.OCRService + geminiService *service.GeminiService } -func NewOCRService(tencentSecretID, tencentSecretKey string, geminiService *service.GeminiService) *OCRService { - return &OCRService{ - tencentSecretID: tencentSecretID, - tencentSecretKey: tencentSecretKey, - geminiService: geminiService, +func NewOCRHandler(ocrService *service.OCRService, geminiService *service.GeminiService) *OCRHandler { + return &OCRHandler{ + ocrService: ocrService, + geminiService: geminiService, } } -func (s *OCRService) ProcessImage(ctx context.Context, imageBase64 string) (string, error) { - // Initialize Tencent Cloud client - credential := common.NewCredential(s.tencentSecretID, s.tencentSecretKey) - cpf := profile.NewClientProfile() - cpf.HttpProfile.Endpoint = "ocr.tencentcloudapi.com" - client, err := ocr.NewClient(credential, "", cpf) - if err != nil { - return "", err - } - - // Create OCR request - request := ocr.NewGeneralHandwritingOCRRequest() - request.ImageBase64 = common.StringPtr(imageBase64) - - // Perform OCR - response, err := client.GeneralHandwritingOCR(request) - if err != nil { - return "", err - } - - // Extract text from OCR response - var ocrText string - for _, textDetection := range response.Response.TextDetections { - ocrText += *textDetection.DetectedText + "\n" - } - - return ocrText, nil -} - type OCRRequest struct { ImageBase64 string `json:"image_base64"` ImageURL string `json:"image_url"` Scene string `json:"scene"` - APIKey string `json:"apikey" binding:"required"` } type OCRResponse struct { OriginalText string `json:"original_text"` - Result string `json:"result"` - Success bool `json:"success"` + Result string `json:"result"` + Success bool `json:"success"` + Error string `json:"error,omitempty"` } -func (h *OCRService) HandleOCR(c *gin.Context) { +func (h *OCRHandler) HandleOCR(c *gin.Context) { var req OCRRequest if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, OCRResponse{ Success: false, - Result: "Invalid request format", - }) - return - } - - // Validate API key - if req.APIKey != h.geminiService.APIKey { - c.JSON(http.StatusUnauthorized, OCRResponse{ - Success: false, - Result: "Invalid API key", + Error: "Invalid request format", }) return } @@ -94,17 +47,21 @@ func (h *OCRService) HandleOCR(c *gin.Context) { if req.ImageURL == "" && req.ImageBase64 == "" { c.JSON(http.StatusBadRequest, OCRResponse{ Success: false, - Result: "Either image_url or image_base64 must be provided", + Error: "Either image_url or image_base64 must be provided", }) return } // Process image - ocrText, err := h.ProcessImage(c.Request.Context(), req.ImageBase64) + ocrText, err := h.ocrService.ProcessImage(c.Request.Context(), req.ImageBase64) if err != nil { - c.JSON(http.StatusInternalServerError, OCRResponse{ + status := http.StatusInternalServerError + if errors.IsClientError(err) { + status = http.StatusBadRequest + } + c.JSON(status, OCRResponse{ Success: false, - Result: "OCR processing failed", + Error: err.Error(), }) return } @@ -114,14 +71,14 @@ func (h *OCRService) HandleOCR(c *gin.Context) { if err != nil { c.JSON(http.StatusInternalServerError, OCRResponse{ Success: false, - Result: "Text processing failed", + Error: "Text processing failed: " + err.Error(), }) return } c.JSON(http.StatusOK, OCRResponse{ - Success: true, + Success: true, OriginalText: ocrText, - Result: processedText, + Result: processedText, }) } \ No newline at end of file diff --git a/pkg/handler/rate.go b/pkg/handler/rate.go index bb3e3e8..2ba2b24 100644 --- a/pkg/handler/rate.go +++ b/pkg/handler/rate.go @@ -2,38 +2,28 @@ package handler import ( "net/http" + "github.com/gin-gonic/gin" - "github.com/google/generative-ai-go/genai" - "google.golang.org/api/option" - "encoding/json" - "strings" - "git.disbaidu.com/maxwell/tencent_ocr/pkg/service" + "tencent_ocr/pkg/service" ) type RateHandler struct { - geminiAPIKey string - apiKey string + geminiService *service.GeminiService +} + +func NewRateHandler(geminiService *service.GeminiService) *RateHandler { + return &RateHandler{ + geminiService: geminiService, + } } type RateRequest struct { - Content string `json:"content" binding:"required"` - Criteria string `json:"criteria"` - WritingRequirement string `json:"writing_requirement"` - APIKey string `json:"apikey" binding:"required"` + Text string `json:"text" binding:"required"` } type RateResponse struct { - Rate int `json:"rate"` - Summary string `json:"summary"` - DetailedReview string `json:"detailed_review"` - Success bool `json:"success"` -} - -func NewRateHandler(geminiAPIKey, apiKey string) *RateHandler { - return &RateHandler{ - geminiAPIKey: geminiAPIKey, - apiKey: apiKey, - } + Result string `json:"result"` + Success bool `json:"success"` } func (h *RateHandler) HandleRate(c *gin.Context) { @@ -41,137 +31,23 @@ func (h *RateHandler) HandleRate(c *gin.Context) { if err := c.ShouldBindJSON(&req); err != nil { c.JSON(http.StatusBadRequest, RateResponse{ Success: false, + Result: "Invalid request format", }) return } - // Validate API key - if req.APIKey != h.apiKey { - c.JSON(http.StatusUnauthorized, RateResponse{ - Success: false, - }) - return - } - - // Initialize Gemini client - ctx := c.Request.Context() - client, err := genai.NewClient(ctx, option.WithAPIKey(h.geminiAPIKey)) - if err != nil { - c.JSON(http.StatusInternalServerError, RateResponse{ - Success: false, - }) - return - } - defer client.Close() - - // Prepare criteria - criteria := req.Criteria - if criteria == "" { - criteria = `你是一名语文老师。你正在给学生的作文打分。根据以下中考作文评分标准,给作文打分。 -## 评分总分值:100分。 -### 88-100分 符合题意;写作目的和对象明确;思考充分,立意深刻,感情真挚;选材精当,内容充实;中心突出,条理清晰;表达准确,语言流畅。 -### 75-87分 符合题意;写作目的和对象较明确;思考较充分,立意清楚,感情真实;选材合理,内容具体;中心明确,有一定条理;表达较准确,语言通畅。 -### 60-74分 符合题意;写作目的和对象较模糊;有一定思考,感情真实;有一定内容;结构基本完整;语言尚通畅。 -### 60分以下 不符合题意;缺乏写作目的和对象;基本没有思考,感情虚假;内容空洞;结构混乱;不成篇。` - } - writing_requirement := req.WritingRequirement - if writing_requirement == "" { - writing_requirement = "写一篇不少于600字的作文,体裁不限。" - } - - // 规定输出格式是json,包含rate, summary, detailed_review,放入prompt的最后 - format := `请按照以下JSON格式输出: -{ - "rate": 分数, // 最多100分制的分数,int类型 - "summary": "总体评价", // 100字以内的总体评价,string类型 - "detailed_review": "详细点评" // 300字以内的详细点评,包含优点和建议,string类型 -}` - // Prepare prompt - prompt := "作文要求:\n" + writing_requirement + "\n\n" + "评分标准:\n" + criteria + format + "\n\n" + "\n\n作文内容:\n" + req.Content - - // Generate content - model := client.GenerativeModel("gemini-2.0-flash-exp") - resp, err := model.GenerateContent(ctx, genai.Text(prompt)) + // Process with Gemini + result, err := h.geminiService.ProcessText(c.Request.Context(), req.Text) if err != nil { c.JSON(http.StatusInternalServerError, RateResponse{ Success: false, + Result: "Text processing failed", }) return } - if len(resp.Candidates) > 0 && len(resp.Candidates[0].Content.Parts) > 0 { - if textPart, ok := resp.Candidates[0].Content.Parts[0].(genai.Text); ok { - // Parse the response to extract rate, summary, and detailed review - result := parseRateResponse(string(textPart)) - - c.JSON(http.StatusOK, RateResponse{ - Rate: result.Rate, - Summary: result.Summary, - DetailedReview: result.Detailed, - Success: true, - }) - return - } - } - - c.JSON(http.StatusInternalServerError, RateResponse{ - Success: false, + c.JSON(http.StatusOK, RateResponse{ + Success: true, + Result: result, }) -} - -type rateResult struct { - Rate int `json:"rate"` - Summary string `json:"summary"` - Detailed string `json:"detailed_review"` -} - -func parseRateResponse(response string) rateResult { - var result rateResult - //去除所有\n - response = strings.ReplaceAll(response, "\n", "") - //去除所有\t - response = strings.ReplaceAll(response, "\t", "") - // 去除response中的```json前缀和```后缀 - response = strings.TrimSpace(response) - response = strings.TrimPrefix(response, "```json") - response = strings.TrimSuffix(response, "```") - - - // 检查response是否是json格式 - if !strings.HasPrefix(response, "{") { - return rateResult{ - Rate: 0, - Summary: "解析失败", - Detailed: "没有左括号", - } - } - if !strings.HasSuffix(response, "}") { - return rateResult{ - Rate: 0, - Summary: "解析失败", - Detailed: "没有右括号", - } - } - - // 解析json - err := json.Unmarshal([]byte(response), &result) - if err != nil { - return rateResult{ - Rate: 0, - Summary: "解析失败", - Detailed: "反序列化失败", - } - } - - // 合并所有验证条件 - if result.Rate <= 0 || result.Rate > 100 || - result.Summary == "" || result.Detailed == "" { - return rateResult{ - Rate: 0, - Summary: "解析失败", - Detailed: "字段验证条件不满足", - } - } - - return result } \ No newline at end of file diff --git a/pkg/handler/upload.go b/pkg/handler/upload.go index cf4fa4a..17f7496 100644 --- a/pkg/handler/upload.go +++ b/pkg/handler/upload.go @@ -1,51 +1,65 @@ // 上传文件到cloudflare R2 package handler + import ( - "bytes" - "fmt" - "net/http" - "github.com/gin-gonic/gin" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/s3" "encoding/base64" "io" + "net/http" "strings" - "git.disbaidu.com/maxwell/tencent_ocr/pkg/service" + + "github.com/gin-gonic/gin" + "tencent_ocr/pkg/service" ) type UploadHandler struct { - accessKey string - secretKey string - bucket string - endpoint string - customDomain string - ocrService *OCRService + uploadService *service.UploadService + ocrService *service.OCRService geminiService *service.GeminiService } +func NewUploadHandler( + uploadService *service.UploadService, + ocrService *service.OCRService, + geminiService *service.GeminiService, +) *UploadHandler { + return &UploadHandler{ + uploadService: uploadService, + ocrService: ocrService, + geminiService: geminiService, + } +} + type MultiUploadResponse struct { ImageURLs []string `json:"image_urls"` - Text string `json:"text"` - Success bool `json:"success"` + Text string `json:"text"` + Success bool `json:"success"` + Error string `json:"error,omitempty"` } -func (h *UploadHandler) HandleMultiUpload(c *gin.Context) { +func (h *UploadHandler) HandleUpload(c *gin.Context) { form, err := c.MultipartForm() if err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to parse form"}) + c.JSON(http.StatusBadRequest, MultiUploadResponse{ + Success: false, + Error: "Failed to parse form", + }) return } files := form.File["files"] if len(files) == 0 { - c.JSON(http.StatusBadRequest, gin.H{"error": "No files uploaded"}) + c.JSON(http.StatusBadRequest, MultiUploadResponse{ + Success: false, + Error: "No files uploaded", + }) return } if len(files) > 5 { - c.JSON(http.StatusBadRequest, gin.H{"error": "Maximum 5 files allowed"}) + c.JSON(http.StatusBadRequest, MultiUploadResponse{ + Success: false, + Error: "Maximum 5 files allowed", + }) return } @@ -54,46 +68,64 @@ func (h *UploadHandler) HandleMultiUpload(c *gin.Context) { for _, fileHeader := range files { if fileHeader.Size > 10<<20 { // 10MB - c.JSON(http.StatusBadRequest, gin.H{"error": "File size exceeds the limit of 10MB"}) + c.JSON(http.StatusBadRequest, MultiUploadResponse{ + Success: false, + Error: "File size exceeds the limit of 10MB", + }) return } file, err := fileHeader.Open() if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to open file"}) + c.JSON(http.StatusInternalServerError, MultiUploadResponse{ + Success: false, + Error: "Failed to open file", + }) return } defer file.Close() - // Read file content + // Read file content for content type detection fileBytes, err := io.ReadAll(file) if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to read file"}) + c.JSON(http.StatusInternalServerError, MultiUploadResponse{ + Success: false, + Error: "Failed to read file", + }) return } // Verify file type contentType := http.DetectContentType(fileBytes) - if !isImage(contentType) { - c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid file type. Only images are allowed"}) + if !h.uploadService.IsValidFileType(contentType) { + c.JSON(http.StatusBadRequest, MultiUploadResponse{ + Success: false, + Error: "Invalid file type. Only images are allowed", + }) return } - // Convert to base64 + // Convert to base64 for OCR base64Str := base64.StdEncoding.EncodeToString(fileBytes) // Process OCR ocrText, err := h.ocrService.ProcessImage(c.Request.Context(), base64Str) if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "OCR processing failed"}) + c.JSON(http.StatusInternalServerError, MultiUploadResponse{ + Success: false, + Error: "OCR processing failed", + }) return } ocrTexts = append(ocrTexts, ocrText) // Upload to R2 - imageURL, err := h.uploadToR2(fileBytes, fileHeader.Filename, contentType) + imageURL, err := h.uploadService.UploadFile(file, fileHeader.Filename, contentType) if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to upload file"}) + c.JSON(http.StatusInternalServerError, MultiUploadResponse{ + Success: false, + Error: "Failed to upload file", + }) return } imageURLs = append(imageURLs, imageURL) @@ -105,7 +137,10 @@ func (h *UploadHandler) HandleMultiUpload(c *gin.Context) { prompt := "请将以下多段文字重新组织成一段通顺的文字,保持原意的同时确保语法和逻辑正确:\n\n" + finalText processedText, err := h.geminiService.ProcessText(c.Request.Context(), prompt) if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "Text processing failed"}) + c.JSON(http.StatusInternalServerError, MultiUploadResponse{ + Success: false, + Error: "Text processing failed", + }) return } finalText = processedText @@ -116,47 +151,4 @@ func (h *UploadHandler) HandleMultiUpload(c *gin.Context) { Text: finalText, Success: true, }) -} - -// uploadToR2 上传文件到Cloudflare R2 -func (h *UploadHandler) uploadToR2(file []byte, fileName, contentType string) (string, error) { - // 创建S3会话 - sess, err := session.NewSession(&aws.Config{ - Endpoint: aws.String(h.endpoint), - Region: aws.String("auto"), - Credentials: credentials.NewStaticCredentials(h.accessKey, h.secretKey, ""), - }) - if err != nil { - return "", fmt.Errorf("failed to create S3 session: %v", err) - } - - // 创建S3服务客户端 - svc := s3.New(sess) - - // 上传文件到R2 - _, err = svc.PutObject(&s3.PutObjectInput{ - Bucket: aws.String(h.bucket), - Key: aws.String(fileName), - Body: bytes.NewReader(file), - ContentType: aws.String(contentType), - ACL: aws.String("public-read"), // 设置文件为公开可读 - }) - if err != nil { - return "", fmt.Errorf("failed to upload file to R2: %v", err) - } - - // 生成文件的URL - imageURL := fmt.Sprintf("https://%s/%s", h.customDomain, fileName) - return imageURL, nil -} - -// isImage 检查文件是否是图片 -func isImage(contentType string) bool { - allowedTypes := []string{"image/jpeg", "image/png", "image/gif", "image/bmp", "image/tiff", "image/webp"} - for _, t := range allowedTypes { - if contentType == t { - return true - } - } - return false } \ No newline at end of file diff --git a/pkg/middleware/auth.go b/pkg/middleware/auth.go new file mode 100644 index 0000000..0e61b4f --- /dev/null +++ b/pkg/middleware/auth.go @@ -0,0 +1,37 @@ +package middleware + +import ( + "net/http" + + "github.com/gin-gonic/gin" +) + +const ( + APIKeyHeader = "X-API-Key" +) + +// APIKeyAuth creates a middleware that validates the API key +func APIKeyAuth(validAPIKey string) gin.HandlerFunc { + return func(c *gin.Context) { + apiKey := c.GetHeader(APIKeyHeader) + if apiKey == "" { + c.JSON(http.StatusUnauthorized, gin.H{ + "success": false, + "error": "API key is required", + }) + c.Abort() + return + } + + if apiKey != validAPIKey { + c.JSON(http.StatusUnauthorized, gin.H{ + "success": false, + "error": "Invalid API key", + }) + c.Abort() + return + } + + c.Next() + } +} \ No newline at end of file diff --git a/pkg/middleware/ratelimit.go b/pkg/middleware/ratelimit.go new file mode 100644 index 0000000..64e6fb4 --- /dev/null +++ b/pkg/middleware/ratelimit.go @@ -0,0 +1,76 @@ +package middleware + +import ( + "net/http" + "sync" + "time" + + "github.com/gin-gonic/gin" + "golang.org/x/time/rate" +) + +type RateLimiter struct { + limiters map[string]*rate.Limiter + mu sync.RWMutex + r rate.Limit + b int +} + +func NewRateLimiter(r rate.Limit, b int) *RateLimiter { + return &RateLimiter{ + limiters: make(map[string]*rate.Limiter), + r: r, + b: b, + } +} + +func (rl *RateLimiter) getLimiter(key string) *rate.Limiter { + rl.mu.Lock() + defer rl.mu.Unlock() + + limiter, exists := rl.limiters[key] + if !exists { + limiter = rate.NewLimiter(rl.r, rl.b) + rl.limiters[key] = limiter + } + + return limiter +} + +// RateLimit creates a middleware that limits request rates per API key +func RateLimit(r rate.Limit, burst int) gin.HandlerFunc { + rl := NewRateLimiter(r, burst) + + return func(c *gin.Context) { + key := c.GetHeader(APIKeyHeader) + if key == "" { + key = c.ClientIP() + } + + limiter := rl.getLimiter(key) + if !limiter.Allow() { + c.JSON(http.StatusTooManyRequests, gin.H{ + "success": false, + "error": "Rate limit exceeded", + }) + c.Abort() + return + } + + c.Next() + } +} + +// CleanupTask periodically removes unused limiters +func (rl *RateLimiter) CleanupTask(interval time.Duration) { + ticker := time.NewTicker(interval) + go func() { + for range ticker.C { + rl.mu.Lock() + for key := range rl.limiters { + delete(rl.limiters, key) + } + rl.mu.Unlock() + } + }() +} \ No newline at end of file diff --git a/pkg/service/gemini.go b/pkg/service/gemini.go index 4d7bc8b..1b775da 100644 --- a/pkg/service/gemini.go +++ b/pkg/service/gemini.go @@ -2,45 +2,61 @@ package service import ( "context" + "sync" + "github.com/google/generative-ai-go/genai" "google.golang.org/api/option" + "tencent_ocr/pkg/errors" ) type GeminiService struct { - apiKey string client *genai.Client + model *genai.GenerativeModel + apiKey string + mu sync.RWMutex } func NewGeminiService(apiKey string) (*GeminiService, error) { ctx := context.Background() client, err := genai.NewClient(ctx, option.WithAPIKey(apiKey)) if err != nil { - return nil, err + return nil, errors.Wrap(err, "failed to create Gemini client") } return &GeminiService{ - apiKey: apiKey, - client: client, + client: client, + model: client.GenerativeModel("gemini-pro"), + apiKey: apiKey, }, nil } -func (s *GeminiService) Close() { +func (s *GeminiService) ProcessText(ctx context.Context, text string) (string, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + resp, err := s.model.GenerateContent(ctx, genai.Text(text)) + if err != nil { + return "", errors.Wrap(err, "failed to generate content") + } + + if len(resp.Candidates) == 0 || len(resp.Candidates[0].Content.Parts) == 0 { + return "", errors.New("no response from Gemini") + } + + if textPart, ok := resp.Candidates[0].Content.Parts[0].(genai.Text); ok { + return string(textPart), nil + } + + return "", errors.New("invalid response format from Gemini") +} + +// Close implements graceful shutdown +func (s *GeminiService) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + if s.client != nil { s.client.Close() } -} - -func (s *GeminiService) ProcessText(ctx context.Context, prompt string) (string, error) { - model := s.client.GenerativeModel("gemini-2.0-flash-exp") - resp, err := model.GenerateContent(ctx, genai.Text(prompt)) - if err != nil { - return "", err - } - - if len(resp.Candidates) > 0 && len(resp.Candidates[0].Content.Parts) > 0 { - if textPart, ok := resp.Candidates[0].Content.Parts[0].(genai.Text); ok { - return string(textPart), nil - } - } - return "", nil + return nil } \ No newline at end of file diff --git a/pkg/service/ocr.go b/pkg/service/ocr.go new file mode 100644 index 0000000..489b68a --- /dev/null +++ b/pkg/service/ocr.go @@ -0,0 +1,71 @@ +package service + +import ( + "context" + "sync" + + "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common" + "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/profile" + ocr "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/ocr/v20181119" + "tencent_ocr/pkg/errors" +) + +type OCRService struct { + client *ocr.Client + tencentSecretID string + tencentSecretKey string + mu sync.RWMutex +} + +func NewOCRService(tencentSecretID, tencentSecretKey string) (*OCRService, error) { + credential := common.NewCredential(tencentSecretID, tencentSecretKey) + cpf := profile.NewClientProfile() + cpf.HttpProfile.Endpoint = "ocr.tencentcloudapi.com" + + client, err := ocr.NewClient(credential, "", cpf) + if err != nil { + return nil, errors.Wrap(err, "failed to create Tencent Cloud OCR client") + } + + return &OCRService{ + client: client, + tencentSecretID: tencentSecretID, + tencentSecretKey: tencentSecretKey, + }, nil +} + +func (s *OCRService) ProcessImage(ctx context.Context, imageBase64 string) (string, error) { + if imageBase64 == "" { + return "", errors.NewClientError("image data is required") + } + + s.mu.RLock() + defer s.mu.RUnlock() + + // Create OCR request + request := ocr.NewGeneralHandwritingOCRRequest() + request.ImageBase64 = common.StringPtr(imageBase64) + + // Perform OCR + response, err := s.client.GeneralHandwritingOCRWithContext(ctx, request) + if err != nil { + return "", errors.Wrap(err, "failed to perform OCR") + } + + // Extract text from OCR response + var ocrText string + for _, textDetection := range response.Response.TextDetections { + ocrText += *textDetection.DetectedText + "\n" + } + + return ocrText, nil +} + +// Close implements graceful shutdown +func (s *OCRService) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + + // Add any cleanup logic here if needed + return nil +} \ No newline at end of file diff --git a/pkg/service/upload.go b/pkg/service/upload.go new file mode 100644 index 0000000..8e5bd8b --- /dev/null +++ b/pkg/service/upload.go @@ -0,0 +1,92 @@ +package service + +import ( + "bytes" + "fmt" + "io" + + "sync" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/s3" + "tencent_ocr/pkg/errors" +) + +type UploadService struct { + s3Client *s3.S3 + bucket string + customDomain string + mu sync.RWMutex +} + +func NewUploadService(accessKey, secretKey, bucket, endpoint, customDomain string) (*UploadService, error) { + sess, err := session.NewSession(&aws.Config{ + Endpoint: aws.String(endpoint), + Region: aws.String("auto"), + Credentials: credentials.NewStaticCredentials(accessKey, secretKey, ""), + }) + if err != nil { + return nil, errors.Wrap(err, "failed to create S3 session") + } + + return &UploadService{ + s3Client: s3.New(sess), + bucket: bucket, + customDomain: customDomain, + }, nil +} + +func (s *UploadService) UploadFile(file io.Reader, fileName, contentType string) (string, error) { + s.mu.RLock() + defer s.mu.RUnlock() + + // Read file content + fileBytes, err := io.ReadAll(file) + if err != nil { + return "", errors.Wrap(err, "failed to read file") + } + + // Upload file to R2 + _, err = s.s3Client.PutObject(&s3.PutObjectInput{ + Bucket: aws.String(s.bucket), + Key: aws.String(fileName), + Body: bytes.NewReader(fileBytes), + ContentType: aws.String(contentType), + ACL: aws.String("public-read"), + }) + if err != nil { + return "", errors.Wrap(err, "failed to upload file to R2") + } + + // Generate file URL + imageURL := fmt.Sprintf("https://%s/%s", s.customDomain, fileName) + return imageURL, nil +} + +func (s *UploadService) IsValidFileType(contentType string) bool { + allowedTypes := []string{ + "image/jpeg", + "image/png", + "image/gif", + "image/bmp", + "image/tiff", + "image/webp", + } + for _, t := range allowedTypes { + if contentType == t { + return true + } + } + return false +} + +// Close implements graceful shutdown +func (s *UploadService) Close() error { + s.mu.Lock() + defer s.mu.Unlock() + + // Add any cleanup logic here if needed + return nil +} \ No newline at end of file