diff --git a/.history/pkg/handler/ocr_20250115160204.go b/.history/pkg/handler/ocr_20250115160204.go new file mode 100644 index 0000000..dd6ba51 --- /dev/null +++ b/.history/pkg/handler/ocr_20250115160204.go @@ -0,0 +1,127 @@ +package handler + +import ( + "context" + "encoding/base64" + "net/http" + "strings" + + "github.com/gin-gonic/gin" + "github.com/google/generative-ai-go/genai" + "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common" + "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/profile" + ocr "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/ocr/v20181119" + "google.golang.org/api/option" + "git.disbaidu.com/maxwell/" +) + +type OCRService struct { + tencentSecretID string + tencentSecretKey string + geminiService *service.GeminiService +} + +func NewOCRService(tencentSecretID, tencentSecretKey string, geminiService *service.GeminiService) *OCRService { + return &OCRService{ + tencentSecretID: tencentSecretID, + tencentSecretKey: tencentSecretKey, + geminiService: geminiService, + } +} + +func (s *OCRService) ProcessImage(ctx context.Context, imageBase64 string) (string, error) { + // Initialize Tencent Cloud client + credential := common.NewCredential(s.tencentSecretID, s.tencentSecretKey) + cpf := profile.NewClientProfile() + cpf.HttpProfile.Endpoint = "ocr.tencentcloudapi.com" + client, err := ocr.NewClient(credential, "", cpf) + if err != nil { + return "", err + } + + // Create OCR request + request := ocr.NewGeneralHandwritingOCRRequest() + request.ImageBase64 = common.StringPtr(imageBase64) + + // Perform OCR + response, err := client.GeneralHandwritingOCR(request) + if err != nil { + return "", err + } + + // Extract text from OCR response + var ocrText string + for _, textDetection := range response.Response.TextDetections { + ocrText += *textDetection.DetectedText + "\n" + } + + return ocrText, nil +} + +type OCRRequest struct { + ImageBase64 string `json:"image_base64"` + ImageURL string `json:"image_url"` + Scene string `json:"scene"` + APIKey string `json:"apikey" binding:"required"` +} + +type OCRResponse struct { + OriginalText string `json:"original_text"` + Result string `json:"result"` + Success bool `json:"success"` +} + +func (h *OCRService) HandleOCR(c *gin.Context) { + var req OCRRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, OCRResponse{ + Success: false, + Result: "Invalid request format", + }) + return + } + + // Validate API key + if req.APIKey != h.geminiService.APIKey { + c.JSON(http.StatusUnauthorized, OCRResponse{ + Success: false, + Result: "Invalid API key", + }) + return + } + + // Validate that at least one of ImageURL or ImageBase64 is provided + if req.ImageURL == "" && req.ImageBase64 == "" { + c.JSON(http.StatusBadRequest, OCRResponse{ + Success: false, + Result: "Either image_url or image_base64 must be provided", + }) + return + } + + // Process image + ocrText, err := h.ProcessImage(c.Request.Context(), req.ImageBase64) + if err != nil { + c.JSON(http.StatusInternalServerError, OCRResponse{ + Success: false, + Result: "OCR processing failed", + }) + return + } + + // Process with Gemini + processedText, err := h.geminiService.ProcessText(c.Request.Context(), ocrText) + if err != nil { + c.JSON(http.StatusInternalServerError, OCRResponse{ + Success: false, + Result: "Text processing failed", + }) + return + } + + c.JSON(http.StatusOK, OCRResponse{ + Success: true, + OriginalText: ocrText, + Result: processedText, + }) +} \ No newline at end of file diff --git a/.history/pkg/handler/ocr_20250115160239.go b/.history/pkg/handler/ocr_20250115160239.go new file mode 100644 index 0000000..cf92e35 --- /dev/null +++ b/.history/pkg/handler/ocr_20250115160239.go @@ -0,0 +1,127 @@ +package handler + +import ( + "context" + "encoding/base64" + "net/http" + "strings" + + "github.com/gin-gonic/gin" + "github.com/google/generative-ai-go/genai" + "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common" + "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/profile" + ocr "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/ocr/v20181119" + "google.golang.org/api/option" + "git.disbaidu.com/maxwell/tencent_ocr/" +) + +type OCRService struct { + tencentSecretID string + tencentSecretKey string + geminiService *service.GeminiService +} + +func NewOCRService(tencentSecretID, tencentSecretKey string, geminiService *service.GeminiService) *OCRService { + return &OCRService{ + tencentSecretID: tencentSecretID, + tencentSecretKey: tencentSecretKey, + geminiService: geminiService, + } +} + +func (s *OCRService) ProcessImage(ctx context.Context, imageBase64 string) (string, error) { + // Initialize Tencent Cloud client + credential := common.NewCredential(s.tencentSecretID, s.tencentSecretKey) + cpf := profile.NewClientProfile() + cpf.HttpProfile.Endpoint = "ocr.tencentcloudapi.com" + client, err := ocr.NewClient(credential, "", cpf) + if err != nil { + return "", err + } + + // Create OCR request + request := ocr.NewGeneralHandwritingOCRRequest() + request.ImageBase64 = common.StringPtr(imageBase64) + + // Perform OCR + response, err := client.GeneralHandwritingOCR(request) + if err != nil { + return "", err + } + + // Extract text from OCR response + var ocrText string + for _, textDetection := range response.Response.TextDetections { + ocrText += *textDetection.DetectedText + "\n" + } + + return ocrText, nil +} + +type OCRRequest struct { + ImageBase64 string `json:"image_base64"` + ImageURL string `json:"image_url"` + Scene string `json:"scene"` + APIKey string `json:"apikey" binding:"required"` +} + +type OCRResponse struct { + OriginalText string `json:"original_text"` + Result string `json:"result"` + Success bool `json:"success"` +} + +func (h *OCRService) HandleOCR(c *gin.Context) { + var req OCRRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, OCRResponse{ + Success: false, + Result: "Invalid request format", + }) + return + } + + // Validate API key + if req.APIKey != h.geminiService.APIKey { + c.JSON(http.StatusUnauthorized, OCRResponse{ + Success: false, + Result: "Invalid API key", + }) + return + } + + // Validate that at least one of ImageURL or ImageBase64 is provided + if req.ImageURL == "" && req.ImageBase64 == "" { + c.JSON(http.StatusBadRequest, OCRResponse{ + Success: false, + Result: "Either image_url or image_base64 must be provided", + }) + return + } + + // Process image + ocrText, err := h.ProcessImage(c.Request.Context(), req.ImageBase64) + if err != nil { + c.JSON(http.StatusInternalServerError, OCRResponse{ + Success: false, + Result: "OCR processing failed", + }) + return + } + + // Process with Gemini + processedText, err := h.geminiService.ProcessText(c.Request.Context(), ocrText) + if err != nil { + c.JSON(http.StatusInternalServerError, OCRResponse{ + Success: false, + Result: "Text processing failed", + }) + return + } + + c.JSON(http.StatusOK, OCRResponse{ + Success: true, + OriginalText: ocrText, + Result: processedText, + }) +} \ No newline at end of file diff --git a/.history/pkg/handler/ocr_20250115160319.go b/.history/pkg/handler/ocr_20250115160319.go new file mode 100644 index 0000000..a242cbf --- /dev/null +++ b/.history/pkg/handler/ocr_20250115160319.go @@ -0,0 +1,127 @@ +package handler + +import ( + "context" + "encoding/base64" + "net/http" + "strings" + + "github.com/gin-gonic/gin" + "github.com/google/generative-ai-go/genai" + "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common" + "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/profile" + ocr "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/ocr/v20181119" + "google.golang.org/api/option" + "git.disbaidu.com/maxwell/tencent_ocr/src/branch/main/pkg/service" +) + +type OCRService struct { + tencentSecretID string + tencentSecretKey string + geminiService *service.GeminiService +} + +func NewOCRService(tencentSecretID, tencentSecretKey string, geminiService *service.GeminiService) *OCRService { + return &OCRService{ + tencentSecretID: tencentSecretID, + tencentSecretKey: tencentSecretKey, + geminiService: geminiService, + } +} + +func (s *OCRService) ProcessImage(ctx context.Context, imageBase64 string) (string, error) { + // Initialize Tencent Cloud client + credential := common.NewCredential(s.tencentSecretID, s.tencentSecretKey) + cpf := profile.NewClientProfile() + cpf.HttpProfile.Endpoint = "ocr.tencentcloudapi.com" + client, err := ocr.NewClient(credential, "", cpf) + if err != nil { + return "", err + } + + // Create OCR request + request := ocr.NewGeneralHandwritingOCRRequest() + request.ImageBase64 = common.StringPtr(imageBase64) + + // Perform OCR + response, err := client.GeneralHandwritingOCR(request) + if err != nil { + return "", err + } + + // Extract text from OCR response + var ocrText string + for _, textDetection := range response.Response.TextDetections { + ocrText += *textDetection.DetectedText + "\n" + } + + return ocrText, nil +} + +type OCRRequest struct { + ImageBase64 string `json:"image_base64"` + ImageURL string `json:"image_url"` + Scene string `json:"scene"` + APIKey string `json:"apikey" binding:"required"` +} + +type OCRResponse struct { + OriginalText string `json:"original_text"` + Result string `json:"result"` + Success bool `json:"success"` +} + +func (h *OCRService) HandleOCR(c *gin.Context) { + var req OCRRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, OCRResponse{ + Success: false, + Result: "Invalid request format", + }) + return + } + + // Validate API key + if req.APIKey != h.geminiService.APIKey { + c.JSON(http.StatusUnauthorized, OCRResponse{ + Success: false, + Result: "Invalid API key", + }) + return + } + + // Validate that at least one of ImageURL or ImageBase64 is provided + if req.ImageURL == "" && req.ImageBase64 == "" { + c.JSON(http.StatusBadRequest, OCRResponse{ + Success: false, + Result: "Either image_url or image_base64 must be provided", + }) + return + } + + // Process image + ocrText, err := h.ProcessImage(c.Request.Context(), req.ImageBase64) + if err != nil { + c.JSON(http.StatusInternalServerError, OCRResponse{ + Success: false, + Result: "OCR processing failed", + }) + return + } + + // Process with Gemini + processedText, err := h.geminiService.ProcessText(c.Request.Context(), ocrText) + if err != nil { + c.JSON(http.StatusInternalServerError, OCRResponse{ + Success: false, + Result: "Text processing failed", + }) + return + } + + c.JSON(http.StatusOK, OCRResponse{ + Success: true, + OriginalText: ocrText, + Result: processedText, + }) +} \ No newline at end of file diff --git a/.history/pkg/handler/ocr_20250115160801.go b/.history/pkg/handler/ocr_20250115160801.go new file mode 100644 index 0000000..6ab82ac --- /dev/null +++ b/.history/pkg/handler/ocr_20250115160801.go @@ -0,0 +1,127 @@ +package handler + +import ( + "context" + "encoding/base64" + "net/http" + "strings" + + "github.com/gin-gonic/gin" + "github.com/google/generative-ai-go/genai" + "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common" + "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/profile" + ocr "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/ocr/v20181119" + "google.golang.org/api/option" + "git.disbaidu.com/maxwell/tencent_ocr/pkg/service" +) + +type OCRService struct { + tencentSecretID string + tencentSecretKey string + geminiService *service.GeminiService +} + +func NewOCRService(tencentSecretID, tencentSecretKey string, geminiService *service.GeminiService) *OCRService { + return &OCRService{ + tencentSecretID: tencentSecretID, + tencentSecretKey: tencentSecretKey, + geminiService: geminiService, + } +} + +func (s *OCRService) ProcessImage(ctx context.Context, imageBase64 string) (string, error) { + // Initialize Tencent Cloud client + credential := common.NewCredential(s.tencentSecretID, s.tencentSecretKey) + cpf := profile.NewClientProfile() + cpf.HttpProfile.Endpoint = "ocr.tencentcloudapi.com" + client, err := ocr.NewClient(credential, "", cpf) + if err != nil { + return "", err + } + + // Create OCR request + request := ocr.NewGeneralHandwritingOCRRequest() + request.ImageBase64 = common.StringPtr(imageBase64) + + // Perform OCR + response, err := client.GeneralHandwritingOCR(request) + if err != nil { + return "", err + } + + // Extract text from OCR response + var ocrText string + for _, textDetection := range response.Response.TextDetections { + ocrText += *textDetection.DetectedText + "\n" + } + + return ocrText, nil +} + +type OCRRequest struct { + ImageBase64 string `json:"image_base64"` + ImageURL string `json:"image_url"` + Scene string `json:"scene"` + APIKey string `json:"apikey" binding:"required"` +} + +type OCRResponse struct { + OriginalText string `json:"original_text"` + Result string `json:"result"` + Success bool `json:"success"` +} + +func (h *OCRService) HandleOCR(c *gin.Context) { + var req OCRRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, OCRResponse{ + Success: false, + Result: "Invalid request format", + }) + return + } + + // Validate API key + if req.APIKey != h.geminiService.APIKey { + c.JSON(http.StatusUnauthorized, OCRResponse{ + Success: false, + Result: "Invalid API key", + }) + return + } + + // Validate that at least one of ImageURL or ImageBase64 is provided + if req.ImageURL == "" && req.ImageBase64 == "" { + c.JSON(http.StatusBadRequest, OCRResponse{ + Success: false, + Result: "Either image_url or image_base64 must be provided", + }) + return + } + + // Process image + ocrText, err := h.ProcessImage(c.Request.Context(), req.ImageBase64) + if err != nil { + c.JSON(http.StatusInternalServerError, OCRResponse{ + Success: false, + Result: "OCR processing failed", + }) + return + } + + // Process with Gemini + processedText, err := h.geminiService.ProcessText(c.Request.Context(), ocrText) + if err != nil { + c.JSON(http.StatusInternalServerError, OCRResponse{ + Success: false, + Result: "Text processing failed", + }) + return + } + + c.JSON(http.StatusOK, OCRResponse{ + Success: true, + OriginalText: ocrText, + Result: processedText, + }) +} \ No newline at end of file diff --git a/.history/pkg/handler/ocr_20250115160803.go b/.history/pkg/handler/ocr_20250115160803.go new file mode 100644 index 0000000..6ab82ac --- /dev/null +++ b/.history/pkg/handler/ocr_20250115160803.go @@ -0,0 +1,127 @@ +package handler + +import ( + "context" + "encoding/base64" + "net/http" + "strings" + + "github.com/gin-gonic/gin" + "github.com/google/generative-ai-go/genai" + "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common" + "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/profile" + ocr "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/ocr/v20181119" + "google.golang.org/api/option" + "git.disbaidu.com/maxwell/tencent_ocr/pkg/service" +) + +type OCRService struct { + tencentSecretID string + tencentSecretKey string + geminiService *service.GeminiService +} + +func NewOCRService(tencentSecretID, tencentSecretKey string, geminiService *service.GeminiService) *OCRService { + return &OCRService{ + tencentSecretID: tencentSecretID, + tencentSecretKey: tencentSecretKey, + geminiService: geminiService, + } +} + +func (s *OCRService) ProcessImage(ctx context.Context, imageBase64 string) (string, error) { + // Initialize Tencent Cloud client + credential := common.NewCredential(s.tencentSecretID, s.tencentSecretKey) + cpf := profile.NewClientProfile() + cpf.HttpProfile.Endpoint = "ocr.tencentcloudapi.com" + client, err := ocr.NewClient(credential, "", cpf) + if err != nil { + return "", err + } + + // Create OCR request + request := ocr.NewGeneralHandwritingOCRRequest() + request.ImageBase64 = common.StringPtr(imageBase64) + + // Perform OCR + response, err := client.GeneralHandwritingOCR(request) + if err != nil { + return "", err + } + + // Extract text from OCR response + var ocrText string + for _, textDetection := range response.Response.TextDetections { + ocrText += *textDetection.DetectedText + "\n" + } + + return ocrText, nil +} + +type OCRRequest struct { + ImageBase64 string `json:"image_base64"` + ImageURL string `json:"image_url"` + Scene string `json:"scene"` + APIKey string `json:"apikey" binding:"required"` +} + +type OCRResponse struct { + OriginalText string `json:"original_text"` + Result string `json:"result"` + Success bool `json:"success"` +} + +func (h *OCRService) HandleOCR(c *gin.Context) { + var req OCRRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, OCRResponse{ + Success: false, + Result: "Invalid request format", + }) + return + } + + // Validate API key + if req.APIKey != h.geminiService.APIKey { + c.JSON(http.StatusUnauthorized, OCRResponse{ + Success: false, + Result: "Invalid API key", + }) + return + } + + // Validate that at least one of ImageURL or ImageBase64 is provided + if req.ImageURL == "" && req.ImageBase64 == "" { + c.JSON(http.StatusBadRequest, OCRResponse{ + Success: false, + Result: "Either image_url or image_base64 must be provided", + }) + return + } + + // Process image + ocrText, err := h.ProcessImage(c.Request.Context(), req.ImageBase64) + if err != nil { + c.JSON(http.StatusInternalServerError, OCRResponse{ + Success: false, + Result: "OCR processing failed", + }) + return + } + + // Process with Gemini + processedText, err := h.geminiService.ProcessText(c.Request.Context(), ocrText) + if err != nil { + c.JSON(http.StatusInternalServerError, OCRResponse{ + Success: false, + Result: "Text processing failed", + }) + return + } + + c.JSON(http.StatusOK, OCRResponse{ + Success: true, + OriginalText: ocrText, + Result: processedText, + }) +} \ No newline at end of file diff --git a/.history/pkg/handler/rate_20250115141957.go b/.history/pkg/handler/rate_20250115141957.go new file mode 100644 index 0000000..4a2f8f0 --- /dev/null +++ b/.history/pkg/handler/rate_20250115141957.go @@ -0,0 +1,176 @@ +package handler + +import ( + "net/http" + "github.com/gin-gonic/gin" + "github.com/google/generative-ai-go/genai" + "google.golang.org/api/option" + "encoding/json" + "strings" +) + +type RateHandler struct { + geminiAPIKey string + apiKey string +} + +type RateRequest struct { + Content string `json:"content" binding:"required"` + Criteria string `json:"criteria"` + WritingRequirement string `json:"writing_requirement"` + APIKey string `json:"apikey" binding:"required"` +} + +type RateResponse struct { + Rate int `json:"rate"` + Summary string `json:"summary"` + DetailedReview string `json:"detailed_review"` + Success bool `json:"success"` +} + +func NewRateHandler(geminiAPIKey, apiKey string) *RateHandler { + return &RateHandler{ + geminiAPIKey: geminiAPIKey, + apiKey: apiKey, + } +} + +func (h *RateHandler) HandleRate(c *gin.Context) { + var req RateRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, RateResponse{ + Success: false, + }) + return + } + + // Validate API key + if req.APIKey != h.apiKey { + c.JSON(http.StatusUnauthorized, RateResponse{ + Success: false, + }) + return + } + + // Initialize Gemini client + ctx := c.Request.Context() + client, err := genai.NewClient(ctx, option.WithAPIKey(h.geminiAPIKey)) + if err != nil { + c.JSON(http.StatusInternalServerError, RateResponse{ + Success: false, + }) + return + } + defer client.Close() + + // Prepare criteria + criteria := req.Criteria + if criteria == "" { + criteria = `你是一名语文老师。你正在给学生的作文打分。根据以下中考作文评分标准,给作文打分。 +## 评分总分值:100分。 +### 88-100分 符合题意;写作目的和对象明确;思考充分,立意深刻,感情真挚;选材精当,内容充实;中心突出,条理清晰;表达准确,语言流畅。 +### 75-87分 符合题意;写作目的和对象较明确;思考较充分,立意清楚,感情真实;选材合理,内容具体;中心明确,有一定条理;表达较准确,语言通畅。 +### 60-74分 符合题意;写作目的和对象较模糊;有一定思考,感情真实;有一定内容;结构基本完整;语言尚通畅。 +### 60分以下 不符合题意;缺乏写作目的和对象;基本没有思考,感情虚假;内容空洞;结构混乱;不成篇。` + } + writing_requirement := req.WritingRequirement + if writing_requirement == "" { + writing_requirement = "写一篇不少于600字的作文,体裁不限。" + } + + // 规定输出格式是json,包含rate, summary, detailed_review,放入prompt的最后 + format := `请按照以下JSON格式输出: +{ + "rate": 分数, // 最多100分制的分数,int类型 + "summary": "总体评价", // 100字以内的总体评价,string类型 + "detailed_review": "详细点评" // 300字以内的详细点评,包含优点和建议,string类型 +}` + // Prepare prompt + prompt := "作文要求:\n" + writing_requirement + "\n\n" + "评分标准:\n" + criteria + format + "\n\n" + "\n\n作文内容:\n" + req.Content + + // Generate content + model := client.GenerativeModel("gemini-2.0-flash-exp") + resp, err := model.GenerateContent(ctx, genai.Text(prompt)) + if err != nil { + c.JSON(http.StatusInternalServerError, RateResponse{ + Success: false, + }) + return + } + + if len(resp.Candidates) > 0 && len(resp.Candidates[0].Content.Parts) > 0 { + if textPart, ok := resp.Candidates[0].Content.Parts[0].(genai.Text); ok { + // Parse the response to extract rate, summary, and detailed review + result := parseRateResponse(string(textPart)) + + c.JSON(http.StatusOK, RateResponse{ + Rate: result.Rate, + Summary: result.Summary, + DetailedReview: result.Detailed, + Success: true, + }) + return + } + } + + c.JSON(http.StatusInternalServerError, RateResponse{ + Success: false, + }) +} + +type rateResult struct { + Rate int `json:"rate"` + Summary string `json:"summary"` + Detailed string `json:"detailed_review"` +} + +func parseRateResponse(response string) rateResult { + var result rateResult + //去除所有\n + response = strings.ReplaceAll(response, "\n", "") + //去除所有\t + response = strings.ReplaceAll(response, "\t", "") + // 去除response中的```json前缀和```后缀 + response = strings.TrimSpace(response) + response = strings.TrimPrefix(response, "```json") + response = strings.TrimSuffix(response, "```") + + + // 检查response是否是json格式 + if !strings.HasPrefix(response, "{") { + return rateResult{ + Rate: 0, + Summary: "解析失败", + Detailed: "没有左括号", + } + } + if !strings.HasSuffix(response, "}") { + return rateResult{ + Rate: 0, + Summary: "解析失败", + Detailed: "没有右括号", + } + } + + // 解析json + err := json.Unmarshal([]byte(response), &result) + if err != nil { + return rateResult{ + Rate: 0, + Summary: "解析失败", + Detailed: "反序列化失败", + } + } + + // 合并所有验证条件 + if result.Rate <= 0 || result.Rate > 100 || + result.Summary == "" || result.Detailed == "" { + return rateResult{ + Rate: 0, + Summary: "解析失败", + Detailed: "字段验证条件不满足", + } + } + + return result +} \ No newline at end of file diff --git a/.history/pkg/handler/rate_20250115160350.go b/.history/pkg/handler/rate_20250115160350.go new file mode 100644 index 0000000..31e90a9 --- /dev/null +++ b/.history/pkg/handler/rate_20250115160350.go @@ -0,0 +1,177 @@ +package handler + +import ( + "net/http" + "github.com/gin-gonic/gin" + "github.com/google/generative-ai-go/genai" + "google.golang.org/api/option" + "encoding/json" + "strings" + "git.disbaidu.com/maxwell/tencent_ocr/src/branch/main/pkg/service" +) + +type RateHandler struct { + geminiAPIKey string + apiKey string +} + +type RateRequest struct { + Content string `json:"content" binding:"required"` + Criteria string `json:"criteria"` + WritingRequirement string `json:"writing_requirement"` + APIKey string `json:"apikey" binding:"required"` +} + +type RateResponse struct { + Rate int `json:"rate"` + Summary string `json:"summary"` + DetailedReview string `json:"detailed_review"` + Success bool `json:"success"` +} + +func NewRateHandler(geminiAPIKey, apiKey string) *RateHandler { + return &RateHandler{ + geminiAPIKey: geminiAPIKey, + apiKey: apiKey, + } +} + +func (h *RateHandler) HandleRate(c *gin.Context) { + var req RateRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, RateResponse{ + Success: false, + }) + return + } + + // Validate API key + if req.APIKey != h.apiKey { + c.JSON(http.StatusUnauthorized, RateResponse{ + Success: false, + }) + return + } + + // Initialize Gemini client + ctx := c.Request.Context() + client, err := genai.NewClient(ctx, option.WithAPIKey(h.geminiAPIKey)) + if err != nil { + c.JSON(http.StatusInternalServerError, RateResponse{ + Success: false, + }) + return + } + defer client.Close() + + // Prepare criteria + criteria := req.Criteria + if criteria == "" { + criteria = `你是一名语文老师。你正在给学生的作文打分。根据以下中考作文评分标准,给作文打分。 +## 评分总分值:100分。 +### 88-100分 符合题意;写作目的和对象明确;思考充分,立意深刻,感情真挚;选材精当,内容充实;中心突出,条理清晰;表达准确,语言流畅。 +### 75-87分 符合题意;写作目的和对象较明确;思考较充分,立意清楚,感情真实;选材合理,内容具体;中心明确,有一定条理;表达较准确,语言通畅。 +### 60-74分 符合题意;写作目的和对象较模糊;有一定思考,感情真实;有一定内容;结构基本完整;语言尚通畅。 +### 60分以下 不符合题意;缺乏写作目的和对象;基本没有思考,感情虚假;内容空洞;结构混乱;不成篇。` + } + writing_requirement := req.WritingRequirement + if writing_requirement == "" { + writing_requirement = "写一篇不少于600字的作文,体裁不限。" + } + + // 规定输出格式是json,包含rate, summary, detailed_review,放入prompt的最后 + format := `请按照以下JSON格式输出: +{ + "rate": 分数, // 最多100分制的分数,int类型 + "summary": "总体评价", // 100字以内的总体评价,string类型 + "detailed_review": "详细点评" // 300字以内的详细点评,包含优点和建议,string类型 +}` + // Prepare prompt + prompt := "作文要求:\n" + writing_requirement + "\n\n" + "评分标准:\n" + criteria + format + "\n\n" + "\n\n作文内容:\n" + req.Content + + // Generate content + model := client.GenerativeModel("gemini-2.0-flash-exp") + resp, err := model.GenerateContent(ctx, genai.Text(prompt)) + if err != nil { + c.JSON(http.StatusInternalServerError, RateResponse{ + Success: false, + }) + return + } + + if len(resp.Candidates) > 0 && len(resp.Candidates[0].Content.Parts) > 0 { + if textPart, ok := resp.Candidates[0].Content.Parts[0].(genai.Text); ok { + // Parse the response to extract rate, summary, and detailed review + result := parseRateResponse(string(textPart)) + + c.JSON(http.StatusOK, RateResponse{ + Rate: result.Rate, + Summary: result.Summary, + DetailedReview: result.Detailed, + Success: true, + }) + return + } + } + + c.JSON(http.StatusInternalServerError, RateResponse{ + Success: false, + }) +} + +type rateResult struct { + Rate int `json:"rate"` + Summary string `json:"summary"` + Detailed string `json:"detailed_review"` +} + +func parseRateResponse(response string) rateResult { + var result rateResult + //去除所有\n + response = strings.ReplaceAll(response, "\n", "") + //去除所有\t + response = strings.ReplaceAll(response, "\t", "") + // 去除response中的```json前缀和```后缀 + response = strings.TrimSpace(response) + response = strings.TrimPrefix(response, "```json") + response = strings.TrimSuffix(response, "```") + + + // 检查response是否是json格式 + if !strings.HasPrefix(response, "{") { + return rateResult{ + Rate: 0, + Summary: "解析失败", + Detailed: "没有左括号", + } + } + if !strings.HasSuffix(response, "}") { + return rateResult{ + Rate: 0, + Summary: "解析失败", + Detailed: "没有右括号", + } + } + + // 解析json + err := json.Unmarshal([]byte(response), &result) + if err != nil { + return rateResult{ + Rate: 0, + Summary: "解析失败", + Detailed: "反序列化失败", + } + } + + // 合并所有验证条件 + if result.Rate <= 0 || result.Rate > 100 || + result.Summary == "" || result.Detailed == "" { + return rateResult{ + Rate: 0, + Summary: "解析失败", + Detailed: "字段验证条件不满足", + } + } + + return result +} \ No newline at end of file diff --git a/.history/pkg/handler/rate_20250115160752.go b/.history/pkg/handler/rate_20250115160752.go new file mode 100644 index 0000000..157039a --- /dev/null +++ b/.history/pkg/handler/rate_20250115160752.go @@ -0,0 +1,178 @@ +package handler + +import ( + "net/http" + "github.com/gin-gonic/gin" + "github.com/google/generative-ai-go/genai" + "google.golang.org/api/option" + "encoding/json" + "strings" + "git.disbaidu.com/maxwell/tencent_ocr/src/branch/main/pkg/service" + "git.disbaidu.com/maxwell/tencent_ocr/pkg/service" +) + +type RateHandler struct { + geminiAPIKey string + apiKey string +} + +type RateRequest struct { + Content string `json:"content" binding:"required"` + Criteria string `json:"criteria"` + WritingRequirement string `json:"writing_requirement"` + APIKey string `json:"apikey" binding:"required"` +} + +type RateResponse struct { + Rate int `json:"rate"` + Summary string `json:"summary"` + DetailedReview string `json:"detailed_review"` + Success bool `json:"success"` +} + +func NewRateHandler(geminiAPIKey, apiKey string) *RateHandler { + return &RateHandler{ + geminiAPIKey: geminiAPIKey, + apiKey: apiKey, + } +} + +func (h *RateHandler) HandleRate(c *gin.Context) { + var req RateRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, RateResponse{ + Success: false, + }) + return + } + + // Validate API key + if req.APIKey != h.apiKey { + c.JSON(http.StatusUnauthorized, RateResponse{ + Success: false, + }) + return + } + + // Initialize Gemini client + ctx := c.Request.Context() + client, err := genai.NewClient(ctx, option.WithAPIKey(h.geminiAPIKey)) + if err != nil { + c.JSON(http.StatusInternalServerError, RateResponse{ + Success: false, + }) + return + } + defer client.Close() + + // Prepare criteria + criteria := req.Criteria + if criteria == "" { + criteria = `你是一名语文老师。你正在给学生的作文打分。根据以下中考作文评分标准,给作文打分。 +## 评分总分值:100分。 +### 88-100分 符合题意;写作目的和对象明确;思考充分,立意深刻,感情真挚;选材精当,内容充实;中心突出,条理清晰;表达准确,语言流畅。 +### 75-87分 符合题意;写作目的和对象较明确;思考较充分,立意清楚,感情真实;选材合理,内容具体;中心明确,有一定条理;表达较准确,语言通畅。 +### 60-74分 符合题意;写作目的和对象较模糊;有一定思考,感情真实;有一定内容;结构基本完整;语言尚通畅。 +### 60分以下 不符合题意;缺乏写作目的和对象;基本没有思考,感情虚假;内容空洞;结构混乱;不成篇。` + } + writing_requirement := req.WritingRequirement + if writing_requirement == "" { + writing_requirement = "写一篇不少于600字的作文,体裁不限。" + } + + // 规定输出格式是json,包含rate, summary, detailed_review,放入prompt的最后 + format := `请按照以下JSON格式输出: +{ + "rate": 分数, // 最多100分制的分数,int类型 + "summary": "总体评价", // 100字以内的总体评价,string类型 + "detailed_review": "详细点评" // 300字以内的详细点评,包含优点和建议,string类型 +}` + // Prepare prompt + prompt := "作文要求:\n" + writing_requirement + "\n\n" + "评分标准:\n" + criteria + format + "\n\n" + "\n\n作文内容:\n" + req.Content + + // Generate content + model := client.GenerativeModel("gemini-2.0-flash-exp") + resp, err := model.GenerateContent(ctx, genai.Text(prompt)) + if err != nil { + c.JSON(http.StatusInternalServerError, RateResponse{ + Success: false, + }) + return + } + + if len(resp.Candidates) > 0 && len(resp.Candidates[0].Content.Parts) > 0 { + if textPart, ok := resp.Candidates[0].Content.Parts[0].(genai.Text); ok { + // Parse the response to extract rate, summary, and detailed review + result := parseRateResponse(string(textPart)) + + c.JSON(http.StatusOK, RateResponse{ + Rate: result.Rate, + Summary: result.Summary, + DetailedReview: result.Detailed, + Success: true, + }) + return + } + } + + c.JSON(http.StatusInternalServerError, RateResponse{ + Success: false, + }) +} + +type rateResult struct { + Rate int `json:"rate"` + Summary string `json:"summary"` + Detailed string `json:"detailed_review"` +} + +func parseRateResponse(response string) rateResult { + var result rateResult + //去除所有\n + response = strings.ReplaceAll(response, "\n", "") + //去除所有\t + response = strings.ReplaceAll(response, "\t", "") + // 去除response中的```json前缀和```后缀 + response = strings.TrimSpace(response) + response = strings.TrimPrefix(response, "```json") + response = strings.TrimSuffix(response, "```") + + + // 检查response是否是json格式 + if !strings.HasPrefix(response, "{") { + return rateResult{ + Rate: 0, + Summary: "解析失败", + Detailed: "没有左括号", + } + } + if !strings.HasSuffix(response, "}") { + return rateResult{ + Rate: 0, + Summary: "解析失败", + Detailed: "没有右括号", + } + } + + // 解析json + err := json.Unmarshal([]byte(response), &result) + if err != nil { + return rateResult{ + Rate: 0, + Summary: "解析失败", + Detailed: "反序列化失败", + } + } + + // 合并所有验证条件 + if result.Rate <= 0 || result.Rate > 100 || + result.Summary == "" || result.Detailed == "" { + return rateResult{ + Rate: 0, + Summary: "解析失败", + Detailed: "字段验证条件不满足", + } + } + + return result +} \ No newline at end of file diff --git a/.history/pkg/handler/rate_20250115160758.go b/.history/pkg/handler/rate_20250115160758.go new file mode 100644 index 0000000..bb3e3e8 --- /dev/null +++ b/.history/pkg/handler/rate_20250115160758.go @@ -0,0 +1,177 @@ +package handler + +import ( + "net/http" + "github.com/gin-gonic/gin" + "github.com/google/generative-ai-go/genai" + "google.golang.org/api/option" + "encoding/json" + "strings" + "git.disbaidu.com/maxwell/tencent_ocr/pkg/service" +) + +type RateHandler struct { + geminiAPIKey string + apiKey string +} + +type RateRequest struct { + Content string `json:"content" binding:"required"` + Criteria string `json:"criteria"` + WritingRequirement string `json:"writing_requirement"` + APIKey string `json:"apikey" binding:"required"` +} + +type RateResponse struct { + Rate int `json:"rate"` + Summary string `json:"summary"` + DetailedReview string `json:"detailed_review"` + Success bool `json:"success"` +} + +func NewRateHandler(geminiAPIKey, apiKey string) *RateHandler { + return &RateHandler{ + geminiAPIKey: geminiAPIKey, + apiKey: apiKey, + } +} + +func (h *RateHandler) HandleRate(c *gin.Context) { + var req RateRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, RateResponse{ + Success: false, + }) + return + } + + // Validate API key + if req.APIKey != h.apiKey { + c.JSON(http.StatusUnauthorized, RateResponse{ + Success: false, + }) + return + } + + // Initialize Gemini client + ctx := c.Request.Context() + client, err := genai.NewClient(ctx, option.WithAPIKey(h.geminiAPIKey)) + if err != nil { + c.JSON(http.StatusInternalServerError, RateResponse{ + Success: false, + }) + return + } + defer client.Close() + + // Prepare criteria + criteria := req.Criteria + if criteria == "" { + criteria = `你是一名语文老师。你正在给学生的作文打分。根据以下中考作文评分标准,给作文打分。 +## 评分总分值:100分。 +### 88-100分 符合题意;写作目的和对象明确;思考充分,立意深刻,感情真挚;选材精当,内容充实;中心突出,条理清晰;表达准确,语言流畅。 +### 75-87分 符合题意;写作目的和对象较明确;思考较充分,立意清楚,感情真实;选材合理,内容具体;中心明确,有一定条理;表达较准确,语言通畅。 +### 60-74分 符合题意;写作目的和对象较模糊;有一定思考,感情真实;有一定内容;结构基本完整;语言尚通畅。 +### 60分以下 不符合题意;缺乏写作目的和对象;基本没有思考,感情虚假;内容空洞;结构混乱;不成篇。` + } + writing_requirement := req.WritingRequirement + if writing_requirement == "" { + writing_requirement = "写一篇不少于600字的作文,体裁不限。" + } + + // 规定输出格式是json,包含rate, summary, detailed_review,放入prompt的最后 + format := `请按照以下JSON格式输出: +{ + "rate": 分数, // 最多100分制的分数,int类型 + "summary": "总体评价", // 100字以内的总体评价,string类型 + "detailed_review": "详细点评" // 300字以内的详细点评,包含优点和建议,string类型 +}` + // Prepare prompt + prompt := "作文要求:\n" + writing_requirement + "\n\n" + "评分标准:\n" + criteria + format + "\n\n" + "\n\n作文内容:\n" + req.Content + + // Generate content + model := client.GenerativeModel("gemini-2.0-flash-exp") + resp, err := model.GenerateContent(ctx, genai.Text(prompt)) + if err != nil { + c.JSON(http.StatusInternalServerError, RateResponse{ + Success: false, + }) + return + } + + if len(resp.Candidates) > 0 && len(resp.Candidates[0].Content.Parts) > 0 { + if textPart, ok := resp.Candidates[0].Content.Parts[0].(genai.Text); ok { + // Parse the response to extract rate, summary, and detailed review + result := parseRateResponse(string(textPart)) + + c.JSON(http.StatusOK, RateResponse{ + Rate: result.Rate, + Summary: result.Summary, + DetailedReview: result.Detailed, + Success: true, + }) + return + } + } + + c.JSON(http.StatusInternalServerError, RateResponse{ + Success: false, + }) +} + +type rateResult struct { + Rate int `json:"rate"` + Summary string `json:"summary"` + Detailed string `json:"detailed_review"` +} + +func parseRateResponse(response string) rateResult { + var result rateResult + //去除所有\n + response = strings.ReplaceAll(response, "\n", "") + //去除所有\t + response = strings.ReplaceAll(response, "\t", "") + // 去除response中的```json前缀和```后缀 + response = strings.TrimSpace(response) + response = strings.TrimPrefix(response, "```json") + response = strings.TrimSuffix(response, "```") + + + // 检查response是否是json格式 + if !strings.HasPrefix(response, "{") { + return rateResult{ + Rate: 0, + Summary: "解析失败", + Detailed: "没有左括号", + } + } + if !strings.HasSuffix(response, "}") { + return rateResult{ + Rate: 0, + Summary: "解析失败", + Detailed: "没有右括号", + } + } + + // 解析json + err := json.Unmarshal([]byte(response), &result) + if err != nil { + return rateResult{ + Rate: 0, + Summary: "解析失败", + Detailed: "反序列化失败", + } + } + + // 合并所有验证条件 + if result.Rate <= 0 || result.Rate > 100 || + result.Summary == "" || result.Detailed == "" { + return rateResult{ + Rate: 0, + Summary: "解析失败", + Detailed: "字段验证条件不满足", + } + } + + return result +} \ No newline at end of file diff --git a/.history/pkg/handler/upload_20250115160358.go b/.history/pkg/handler/upload_20250115160358.go new file mode 100644 index 0000000..4bb8054 --- /dev/null +++ b/.history/pkg/handler/upload_20250115160358.go @@ -0,0 +1,162 @@ +// 上传文件到cloudflare R2 +package handler +import ( + "bytes" + "fmt" + "net/http" + "github.com/gin-gonic/gin" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/s3" + "encoding/base64" + "io" + "strings" + "git.disbaidu.com/maxwell/tencent_ocr/src/branch/main/pkg/service" +) + +type UploadHandler struct { + accessKey string + secretKey string + bucket string + endpoint string + customDomain string + ocrService *OCRService + geminiService *service.GeminiService +} + +type MultiUploadResponse struct { + ImageURLs []string `json:"image_urls"` + Text string `json:"text"` + Success bool `json:"success"` +} + +func (h *UploadHandler) HandleMultiUpload(c *gin.Context) { + form, err := c.MultipartForm() + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to parse form"}) + return + } + + files := form.File["files"] + if len(files) == 0 { + c.JSON(http.StatusBadRequest, gin.H{"error": "No files uploaded"}) + return + } + + if len(files) > 5 { + c.JSON(http.StatusBadRequest, gin.H{"error": "Maximum 5 files allowed"}) + return + } + + var imageURLs []string + var ocrTexts []string + + for _, fileHeader := range files { + if fileHeader.Size > 10<<20 { // 10MB + c.JSON(http.StatusBadRequest, gin.H{"error": "File size exceeds the limit of 10MB"}) + return + } + + file, err := fileHeader.Open() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to open file"}) + return + } + defer file.Close() + + // Read file content + fileBytes, err := io.ReadAll(file) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to read file"}) + return + } + + // Verify file type + contentType := http.DetectContentType(fileBytes) + if !isImage(contentType) { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid file type. Only images are allowed"}) + return + } + + // Convert to base64 + base64Str := base64.StdEncoding.EncodeToString(fileBytes) + + // Process OCR + ocrText, err := h.ocrService.ProcessImage(c.Request.Context(), base64Str) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "OCR processing failed"}) + return + } + ocrTexts = append(ocrTexts, ocrText) + + // Upload to R2 + imageURL, err := h.uploadToR2(fileBytes, fileHeader.Filename, contentType) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to upload file"}) + return + } + imageURLs = append(imageURLs, imageURL) + } + + // Process combined text with Gemini if multiple images + finalText := strings.Join(ocrTexts, "\n") + if len(ocrTexts) > 1 { + prompt := "请将以下多段文字重新组织成一段通顺的文字,保持原意的同时确保语法和逻辑正确:\n\n" + finalText + processedText, err := h.geminiService.ProcessText(c.Request.Context(), prompt) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Text processing failed"}) + return + } + finalText = processedText + } + + c.JSON(http.StatusOK, MultiUploadResponse{ + ImageURLs: imageURLs, + Text: finalText, + Success: true, + }) +} + +// uploadToR2 上传文件到Cloudflare R2 +func (h *UploadHandler) uploadToR2(file []byte, fileName, contentType string) (string, error) { + // 创建S3会话 + sess, err := session.NewSession(&aws.Config{ + Endpoint: aws.String(h.endpoint), + Region: aws.String("auto"), + Credentials: credentials.NewStaticCredentials(h.accessKey, h.secretKey, ""), + }) + if err != nil { + return "", fmt.Errorf("failed to create S3 session: %v", err) + } + + // 创建S3服务客户端 + svc := s3.New(sess) + + // 上传文件到R2 + _, err = svc.PutObject(&s3.PutObjectInput{ + Bucket: aws.String(h.bucket), + Key: aws.String(fileName), + Body: bytes.NewReader(file), + ContentType: aws.String(contentType), + ACL: aws.String("public-read"), // 设置文件为公开可读 + }) + if err != nil { + return "", fmt.Errorf("failed to upload file to R2: %v", err) + } + + // 生成文件的URL + imageURL := fmt.Sprintf("https://%s/%s", h.customDomain, fileName) + return imageURL, nil +} + +// isImage 检查文件是否是图片 +func isImage(contentType string) bool { + allowedTypes := []string{"image/jpeg", "image/png", "image/gif", "image/bmp", "image/tiff", "image/webp"} + for _, t := range allowedTypes { + if contentType == t { + return true + } + } + return false +} \ No newline at end of file diff --git a/.history/pkg/handler/upload_20250115160402.go b/.history/pkg/handler/upload_20250115160402.go new file mode 100644 index 0000000..4bb8054 --- /dev/null +++ b/.history/pkg/handler/upload_20250115160402.go @@ -0,0 +1,162 @@ +// 上传文件到cloudflare R2 +package handler +import ( + "bytes" + "fmt" + "net/http" + "github.com/gin-gonic/gin" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/s3" + "encoding/base64" + "io" + "strings" + "git.disbaidu.com/maxwell/tencent_ocr/src/branch/main/pkg/service" +) + +type UploadHandler struct { + accessKey string + secretKey string + bucket string + endpoint string + customDomain string + ocrService *OCRService + geminiService *service.GeminiService +} + +type MultiUploadResponse struct { + ImageURLs []string `json:"image_urls"` + Text string `json:"text"` + Success bool `json:"success"` +} + +func (h *UploadHandler) HandleMultiUpload(c *gin.Context) { + form, err := c.MultipartForm() + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to parse form"}) + return + } + + files := form.File["files"] + if len(files) == 0 { + c.JSON(http.StatusBadRequest, gin.H{"error": "No files uploaded"}) + return + } + + if len(files) > 5 { + c.JSON(http.StatusBadRequest, gin.H{"error": "Maximum 5 files allowed"}) + return + } + + var imageURLs []string + var ocrTexts []string + + for _, fileHeader := range files { + if fileHeader.Size > 10<<20 { // 10MB + c.JSON(http.StatusBadRequest, gin.H{"error": "File size exceeds the limit of 10MB"}) + return + } + + file, err := fileHeader.Open() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to open file"}) + return + } + defer file.Close() + + // Read file content + fileBytes, err := io.ReadAll(file) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to read file"}) + return + } + + // Verify file type + contentType := http.DetectContentType(fileBytes) + if !isImage(contentType) { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid file type. Only images are allowed"}) + return + } + + // Convert to base64 + base64Str := base64.StdEncoding.EncodeToString(fileBytes) + + // Process OCR + ocrText, err := h.ocrService.ProcessImage(c.Request.Context(), base64Str) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "OCR processing failed"}) + return + } + ocrTexts = append(ocrTexts, ocrText) + + // Upload to R2 + imageURL, err := h.uploadToR2(fileBytes, fileHeader.Filename, contentType) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to upload file"}) + return + } + imageURLs = append(imageURLs, imageURL) + } + + // Process combined text with Gemini if multiple images + finalText := strings.Join(ocrTexts, "\n") + if len(ocrTexts) > 1 { + prompt := "请将以下多段文字重新组织成一段通顺的文字,保持原意的同时确保语法和逻辑正确:\n\n" + finalText + processedText, err := h.geminiService.ProcessText(c.Request.Context(), prompt) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Text processing failed"}) + return + } + finalText = processedText + } + + c.JSON(http.StatusOK, MultiUploadResponse{ + ImageURLs: imageURLs, + Text: finalText, + Success: true, + }) +} + +// uploadToR2 上传文件到Cloudflare R2 +func (h *UploadHandler) uploadToR2(file []byte, fileName, contentType string) (string, error) { + // 创建S3会话 + sess, err := session.NewSession(&aws.Config{ + Endpoint: aws.String(h.endpoint), + Region: aws.String("auto"), + Credentials: credentials.NewStaticCredentials(h.accessKey, h.secretKey, ""), + }) + if err != nil { + return "", fmt.Errorf("failed to create S3 session: %v", err) + } + + // 创建S3服务客户端 + svc := s3.New(sess) + + // 上传文件到R2 + _, err = svc.PutObject(&s3.PutObjectInput{ + Bucket: aws.String(h.bucket), + Key: aws.String(fileName), + Body: bytes.NewReader(file), + ContentType: aws.String(contentType), + ACL: aws.String("public-read"), // 设置文件为公开可读 + }) + if err != nil { + return "", fmt.Errorf("failed to upload file to R2: %v", err) + } + + // 生成文件的URL + imageURL := fmt.Sprintf("https://%s/%s", h.customDomain, fileName) + return imageURL, nil +} + +// isImage 检查文件是否是图片 +func isImage(contentType string) bool { + allowedTypes := []string{"image/jpeg", "image/png", "image/gif", "image/bmp", "image/tiff", "image/webp"} + for _, t := range allowedTypes { + if contentType == t { + return true + } + } + return false +} \ No newline at end of file diff --git a/.history/pkg/handler/upload_20250115160721.go b/.history/pkg/handler/upload_20250115160721.go new file mode 100644 index 0000000..c0559e4 --- /dev/null +++ b/.history/pkg/handler/upload_20250115160721.go @@ -0,0 +1,162 @@ +// 上传文件到cloudflare R2 +package handler +import ( + "bytes" + "fmt" + "net/http" + "github.com/gin-gonic/gin" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/s3" + "encoding/base64" + "io" + "strings" + "tencent_ocr/src/branch/main/pkg/service" +) + +type UploadHandler struct { + accessKey string + secretKey string + bucket string + endpoint string + customDomain string + ocrService *OCRService + geminiService *service.GeminiService +} + +type MultiUploadResponse struct { + ImageURLs []string `json:"image_urls"` + Text string `json:"text"` + Success bool `json:"success"` +} + +func (h *UploadHandler) HandleMultiUpload(c *gin.Context) { + form, err := c.MultipartForm() + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to parse form"}) + return + } + + files := form.File["files"] + if len(files) == 0 { + c.JSON(http.StatusBadRequest, gin.H{"error": "No files uploaded"}) + return + } + + if len(files) > 5 { + c.JSON(http.StatusBadRequest, gin.H{"error": "Maximum 5 files allowed"}) + return + } + + var imageURLs []string + var ocrTexts []string + + for _, fileHeader := range files { + if fileHeader.Size > 10<<20 { // 10MB + c.JSON(http.StatusBadRequest, gin.H{"error": "File size exceeds the limit of 10MB"}) + return + } + + file, err := fileHeader.Open() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to open file"}) + return + } + defer file.Close() + + // Read file content + fileBytes, err := io.ReadAll(file) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to read file"}) + return + } + + // Verify file type + contentType := http.DetectContentType(fileBytes) + if !isImage(contentType) { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid file type. Only images are allowed"}) + return + } + + // Convert to base64 + base64Str := base64.StdEncoding.EncodeToString(fileBytes) + + // Process OCR + ocrText, err := h.ocrService.ProcessImage(c.Request.Context(), base64Str) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "OCR processing failed"}) + return + } + ocrTexts = append(ocrTexts, ocrText) + + // Upload to R2 + imageURL, err := h.uploadToR2(fileBytes, fileHeader.Filename, contentType) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to upload file"}) + return + } + imageURLs = append(imageURLs, imageURL) + } + + // Process combined text with Gemini if multiple images + finalText := strings.Join(ocrTexts, "\n") + if len(ocrTexts) > 1 { + prompt := "请将以下多段文字重新组织成一段通顺的文字,保持原意的同时确保语法和逻辑正确:\n\n" + finalText + processedText, err := h.geminiService.ProcessText(c.Request.Context(), prompt) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Text processing failed"}) + return + } + finalText = processedText + } + + c.JSON(http.StatusOK, MultiUploadResponse{ + ImageURLs: imageURLs, + Text: finalText, + Success: true, + }) +} + +// uploadToR2 上传文件到Cloudflare R2 +func (h *UploadHandler) uploadToR2(file []byte, fileName, contentType string) (string, error) { + // 创建S3会话 + sess, err := session.NewSession(&aws.Config{ + Endpoint: aws.String(h.endpoint), + Region: aws.String("auto"), + Credentials: credentials.NewStaticCredentials(h.accessKey, h.secretKey, ""), + }) + if err != nil { + return "", fmt.Errorf("failed to create S3 session: %v", err) + } + + // 创建S3服务客户端 + svc := s3.New(sess) + + // 上传文件到R2 + _, err = svc.PutObject(&s3.PutObjectInput{ + Bucket: aws.String(h.bucket), + Key: aws.String(fileName), + Body: bytes.NewReader(file), + ContentType: aws.String(contentType), + ACL: aws.String("public-read"), // 设置文件为公开可读 + }) + if err != nil { + return "", fmt.Errorf("failed to upload file to R2: %v", err) + } + + // 生成文件的URL + imageURL := fmt.Sprintf("https://%s/%s", h.customDomain, fileName) + return imageURL, nil +} + +// isImage 检查文件是否是图片 +func isImage(contentType string) bool { + allowedTypes := []string{"image/jpeg", "image/png", "image/gif", "image/bmp", "image/tiff", "image/webp"} + for _, t := range allowedTypes { + if contentType == t { + return true + } + } + return false +} \ No newline at end of file diff --git a/.history/pkg/handler/upload_20250115160729.go b/.history/pkg/handler/upload_20250115160729.go new file mode 100644 index 0000000..4bb8054 --- /dev/null +++ b/.history/pkg/handler/upload_20250115160729.go @@ -0,0 +1,162 @@ +// 上传文件到cloudflare R2 +package handler +import ( + "bytes" + "fmt" + "net/http" + "github.com/gin-gonic/gin" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/s3" + "encoding/base64" + "io" + "strings" + "git.disbaidu.com/maxwell/tencent_ocr/src/branch/main/pkg/service" +) + +type UploadHandler struct { + accessKey string + secretKey string + bucket string + endpoint string + customDomain string + ocrService *OCRService + geminiService *service.GeminiService +} + +type MultiUploadResponse struct { + ImageURLs []string `json:"image_urls"` + Text string `json:"text"` + Success bool `json:"success"` +} + +func (h *UploadHandler) HandleMultiUpload(c *gin.Context) { + form, err := c.MultipartForm() + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to parse form"}) + return + } + + files := form.File["files"] + if len(files) == 0 { + c.JSON(http.StatusBadRequest, gin.H{"error": "No files uploaded"}) + return + } + + if len(files) > 5 { + c.JSON(http.StatusBadRequest, gin.H{"error": "Maximum 5 files allowed"}) + return + } + + var imageURLs []string + var ocrTexts []string + + for _, fileHeader := range files { + if fileHeader.Size > 10<<20 { // 10MB + c.JSON(http.StatusBadRequest, gin.H{"error": "File size exceeds the limit of 10MB"}) + return + } + + file, err := fileHeader.Open() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to open file"}) + return + } + defer file.Close() + + // Read file content + fileBytes, err := io.ReadAll(file) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to read file"}) + return + } + + // Verify file type + contentType := http.DetectContentType(fileBytes) + if !isImage(contentType) { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid file type. Only images are allowed"}) + return + } + + // Convert to base64 + base64Str := base64.StdEncoding.EncodeToString(fileBytes) + + // Process OCR + ocrText, err := h.ocrService.ProcessImage(c.Request.Context(), base64Str) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "OCR processing failed"}) + return + } + ocrTexts = append(ocrTexts, ocrText) + + // Upload to R2 + imageURL, err := h.uploadToR2(fileBytes, fileHeader.Filename, contentType) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to upload file"}) + return + } + imageURLs = append(imageURLs, imageURL) + } + + // Process combined text with Gemini if multiple images + finalText := strings.Join(ocrTexts, "\n") + if len(ocrTexts) > 1 { + prompt := "请将以下多段文字重新组织成一段通顺的文字,保持原意的同时确保语法和逻辑正确:\n\n" + finalText + processedText, err := h.geminiService.ProcessText(c.Request.Context(), prompt) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Text processing failed"}) + return + } + finalText = processedText + } + + c.JSON(http.StatusOK, MultiUploadResponse{ + ImageURLs: imageURLs, + Text: finalText, + Success: true, + }) +} + +// uploadToR2 上传文件到Cloudflare R2 +func (h *UploadHandler) uploadToR2(file []byte, fileName, contentType string) (string, error) { + // 创建S3会话 + sess, err := session.NewSession(&aws.Config{ + Endpoint: aws.String(h.endpoint), + Region: aws.String("auto"), + Credentials: credentials.NewStaticCredentials(h.accessKey, h.secretKey, ""), + }) + if err != nil { + return "", fmt.Errorf("failed to create S3 session: %v", err) + } + + // 创建S3服务客户端 + svc := s3.New(sess) + + // 上传文件到R2 + _, err = svc.PutObject(&s3.PutObjectInput{ + Bucket: aws.String(h.bucket), + Key: aws.String(fileName), + Body: bytes.NewReader(file), + ContentType: aws.String(contentType), + ACL: aws.String("public-read"), // 设置文件为公开可读 + }) + if err != nil { + return "", fmt.Errorf("failed to upload file to R2: %v", err) + } + + // 生成文件的URL + imageURL := fmt.Sprintf("https://%s/%s", h.customDomain, fileName) + return imageURL, nil +} + +// isImage 检查文件是否是图片 +func isImage(contentType string) bool { + allowedTypes := []string{"image/jpeg", "image/png", "image/gif", "image/bmp", "image/tiff", "image/webp"} + for _, t := range allowedTypes { + if contentType == t { + return true + } + } + return false +} \ No newline at end of file diff --git a/.history/pkg/handler/upload_20250115160735.go b/.history/pkg/handler/upload_20250115160735.go new file mode 100644 index 0000000..cf4fa4a --- /dev/null +++ b/.history/pkg/handler/upload_20250115160735.go @@ -0,0 +1,162 @@ +// 上传文件到cloudflare R2 +package handler +import ( + "bytes" + "fmt" + "net/http" + "github.com/gin-gonic/gin" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/s3" + "encoding/base64" + "io" + "strings" + "git.disbaidu.com/maxwell/tencent_ocr/pkg/service" +) + +type UploadHandler struct { + accessKey string + secretKey string + bucket string + endpoint string + customDomain string + ocrService *OCRService + geminiService *service.GeminiService +} + +type MultiUploadResponse struct { + ImageURLs []string `json:"image_urls"` + Text string `json:"text"` + Success bool `json:"success"` +} + +func (h *UploadHandler) HandleMultiUpload(c *gin.Context) { + form, err := c.MultipartForm() + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to parse form"}) + return + } + + files := form.File["files"] + if len(files) == 0 { + c.JSON(http.StatusBadRequest, gin.H{"error": "No files uploaded"}) + return + } + + if len(files) > 5 { + c.JSON(http.StatusBadRequest, gin.H{"error": "Maximum 5 files allowed"}) + return + } + + var imageURLs []string + var ocrTexts []string + + for _, fileHeader := range files { + if fileHeader.Size > 10<<20 { // 10MB + c.JSON(http.StatusBadRequest, gin.H{"error": "File size exceeds the limit of 10MB"}) + return + } + + file, err := fileHeader.Open() + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to open file"}) + return + } + defer file.Close() + + // Read file content + fileBytes, err := io.ReadAll(file) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to read file"}) + return + } + + // Verify file type + contentType := http.DetectContentType(fileBytes) + if !isImage(contentType) { + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid file type. Only images are allowed"}) + return + } + + // Convert to base64 + base64Str := base64.StdEncoding.EncodeToString(fileBytes) + + // Process OCR + ocrText, err := h.ocrService.ProcessImage(c.Request.Context(), base64Str) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "OCR processing failed"}) + return + } + ocrTexts = append(ocrTexts, ocrText) + + // Upload to R2 + imageURL, err := h.uploadToR2(fileBytes, fileHeader.Filename, contentType) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to upload file"}) + return + } + imageURLs = append(imageURLs, imageURL) + } + + // Process combined text with Gemini if multiple images + finalText := strings.Join(ocrTexts, "\n") + if len(ocrTexts) > 1 { + prompt := "请将以下多段文字重新组织成一段通顺的文字,保持原意的同时确保语法和逻辑正确:\n\n" + finalText + processedText, err := h.geminiService.ProcessText(c.Request.Context(), prompt) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "Text processing failed"}) + return + } + finalText = processedText + } + + c.JSON(http.StatusOK, MultiUploadResponse{ + ImageURLs: imageURLs, + Text: finalText, + Success: true, + }) +} + +// uploadToR2 上传文件到Cloudflare R2 +func (h *UploadHandler) uploadToR2(file []byte, fileName, contentType string) (string, error) { + // 创建S3会话 + sess, err := session.NewSession(&aws.Config{ + Endpoint: aws.String(h.endpoint), + Region: aws.String("auto"), + Credentials: credentials.NewStaticCredentials(h.accessKey, h.secretKey, ""), + }) + if err != nil { + return "", fmt.Errorf("failed to create S3 session: %v", err) + } + + // 创建S3服务客户端 + svc := s3.New(sess) + + // 上传文件到R2 + _, err = svc.PutObject(&s3.PutObjectInput{ + Bucket: aws.String(h.bucket), + Key: aws.String(fileName), + Body: bytes.NewReader(file), + ContentType: aws.String(contentType), + ACL: aws.String("public-read"), // 设置文件为公开可读 + }) + if err != nil { + return "", fmt.Errorf("failed to upload file to R2: %v", err) + } + + // 生成文件的URL + imageURL := fmt.Sprintf("https://%s/%s", h.customDomain, fileName) + return imageURL, nil +} + +// isImage 检查文件是否是图片 +func isImage(contentType string) bool { + allowedTypes := []string{"image/jpeg", "image/png", "image/gif", "image/bmp", "image/tiff", "image/webp"} + for _, t := range allowedTypes { + if contentType == t { + return true + } + } + return false +} \ No newline at end of file diff --git a/pkg/handler/ocr.go b/pkg/handler/ocr.go index 9f2d252..6ab82ac 100644 --- a/pkg/handler/ocr.go +++ b/pkg/handler/ocr.go @@ -12,7 +12,7 @@ import ( "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/profile" ocr "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/ocr/v20181119" "google.golang.org/api/option" - "pkg/service" + "git.disbaidu.com/maxwell/tencent_ocr/pkg/service" ) type OCRService struct { diff --git a/pkg/handler/rate.go b/pkg/handler/rate.go index 4a2f8f0..bb3e3e8 100644 --- a/pkg/handler/rate.go +++ b/pkg/handler/rate.go @@ -7,6 +7,7 @@ import ( "google.golang.org/api/option" "encoding/json" "strings" + "git.disbaidu.com/maxwell/tencent_ocr/pkg/service" ) type RateHandler struct { diff --git a/pkg/handler/upload.go b/pkg/handler/upload.go index 630add1..cf4fa4a 100644 --- a/pkg/handler/upload.go +++ b/pkg/handler/upload.go @@ -12,7 +12,7 @@ import ( "encoding/base64" "io" "strings" - "your-project/pkg/service" + "git.disbaidu.com/maxwell/tencent_ocr/pkg/service" ) type UploadHandler struct {