166 lines
4.3 KiB
Go
166 lines
4.3 KiB
Go
package handler
|
||
|
||
import (
	"crypto/subtle"
	"encoding/base64"
	"net/http"
	"strings"

	"github.com/gin-gonic/gin"
	"github.com/google/generative-ai-go/genai"
	"github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common"
	"github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/profile"
	ocr "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/ocr/v20181119"
	"google.golang.org/api/option"
)
|
||
|
||
// OCRHandler serves the OCR endpoint. It carries the Tencent Cloud credential
// pair used for handwriting recognition, the Gemini API key used to correct
// the recognised text, and the shared API key that callers must present.
type OCRHandler struct {
	// Tencent Cloud SecretId / SecretKey, passed to common.NewCredential.
	tencentSecretID, tencentSecretKey string
	// Key handed to the Gemini client via option.WithAPIKey.
	geminiAPIKey string
	// Value every request must send in its "apikey" JSON field.
	apiKey string
}
|
||
|
||
type (
	// OCRRequest is the JSON body accepted by HandleOCR. At least one of
	// ImageBase64 or ImageURL must be non-empty; ImageURL wins when both are set.
	OCRRequest struct {
		ImageBase64 string `json:"image_base64"` // raw base64, or a data URI ending in "base64,<payload>"
		ImageURL    string `json:"image_url"`    // image location forwarded to Tencent OCR
		Scene       string `json:"scene"`        // optional; forwarded as the OCR Scene when non-empty
		APIKey      string `json:"apikey" binding:"required"`
	}

	// OCRResponse is the JSON body HandleOCR returns for both success and failure.
	OCRResponse struct {
		OriginalText string `json:"original_text"` // raw OCR output, one detection per "\n"-terminated line
		Result       string `json:"result"`        // corrected text on success, error message on failure
		Success      bool   `json:"success"`
	}
)
|
||
|
||
func NewOCRHandler(tencentSecretID, tencentSecretKey, geminiAPIKey, apiKey string) *OCRHandler {
|
||
return &OCRHandler{
|
||
tencentSecretID: tencentSecretID,
|
||
tencentSecretKey: tencentSecretKey,
|
||
geminiAPIKey: geminiAPIKey,
|
||
apiKey: apiKey,
|
||
}
|
||
}
|
||
|
||
func (h *OCRHandler) HandleOCR(c *gin.Context) {
|
||
var req OCRRequest
|
||
if err := c.ShouldBindJSON(&req); err != nil {
|
||
c.JSON(http.StatusBadRequest, OCRResponse{
|
||
Success: false,
|
||
Result: "Invalid request format",
|
||
})
|
||
return
|
||
}
|
||
|
||
// Validate API key
|
||
if req.APIKey != h.apiKey {
|
||
c.JSON(http.StatusUnauthorized, OCRResponse{
|
||
Success: false,
|
||
Result: "Invalid API key",
|
||
})
|
||
return
|
||
}
|
||
|
||
// Validate that at least one of ImageURL or ImageBase64 is provided
|
||
if req.ImageURL == "" && req.ImageBase64 == "" {
|
||
c.JSON(http.StatusBadRequest, OCRResponse{
|
||
Success: false,
|
||
Result: "Either image_url or image_base64 must be provided",
|
||
})
|
||
return
|
||
}
|
||
|
||
// Initialize Tencent Cloud client
|
||
credential := common.NewCredential(h.tencentSecretID, h.tencentSecretKey)
|
||
cpf := profile.NewClientProfile()
|
||
cpf.HttpProfile.Endpoint = "ocr.tencentcloudapi.com"
|
||
client, err := ocr.NewClient(credential, "", cpf)
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, OCRResponse{
|
||
Success: false,
|
||
Result: "Failed to initialize OCR client",
|
||
})
|
||
return
|
||
}
|
||
|
||
// Create OCR request
|
||
request := ocr.NewGeneralHandwritingOCRRequest()
|
||
|
||
// Prioritize ImageURL if both are provided
|
||
if req.ImageURL != "" {
|
||
request.ImageUrl = common.StringPtr(req.ImageURL)
|
||
} else {
|
||
// Remove base64 prefix if exists
|
||
imageBase64 := req.ImageBase64
|
||
if idx := strings.Index(imageBase64, "base64,"); idx != -1 {
|
||
imageBase64 = imageBase64[idx+7:] // 7 is the length of "base64,"
|
||
}
|
||
|
||
// Validate base64
|
||
if _, err := base64.StdEncoding.DecodeString(imageBase64); err != nil {
|
||
c.JSON(http.StatusBadRequest, OCRResponse{
|
||
Success: false,
|
||
Result: "Invalid base64 image",
|
||
})
|
||
return
|
||
}
|
||
request.ImageBase64 = common.StringPtr(imageBase64)
|
||
}
|
||
|
||
if req.Scene != "" {
|
||
request.Scene = common.StringPtr(req.Scene)
|
||
}
|
||
|
||
// Perform OCR
|
||
response, err := client.GeneralHandwritingOCR(request)
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, OCRResponse{
|
||
Success: false,
|
||
Result: "OCR processing failed",
|
||
})
|
||
return
|
||
}
|
||
|
||
// Extract text from OCR response
|
||
var ocrText string
|
||
for _, textDetection := range response.Response.TextDetections {
|
||
ocrText += *textDetection.DetectedText + "\n"
|
||
}
|
||
|
||
// Process with Gemini
|
||
ctx := c.Request.Context()
|
||
client2, err := genai.NewClient(ctx, option.WithAPIKey(h.geminiAPIKey))
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, OCRResponse{
|
||
Success: false,
|
||
Result: "Failed to initialize Gemini client",
|
||
})
|
||
return
|
||
}
|
||
defer client2.Close()
|
||
|
||
model := client2.GenerativeModel("gemini-2.0-flash-exp")
|
||
prompt := "你是一个专业的助手,负责纠正OCR识别结果中的文本。只需要输出识别结果,不需要输出任何解释。\n\n" + ocrText
|
||
resp, err := model.GenerateContent(ctx, genai.Text(prompt))
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, OCRResponse{
|
||
Success: false,
|
||
Result: "Text processing failed",
|
||
})
|
||
return
|
||
}
|
||
|
||
// Get the processed text from Gemini response
|
||
processedText := ""
|
||
if len(resp.Candidates) > 0 && len(resp.Candidates[0].Content.Parts) > 0 {
|
||
if textPart, ok := resp.Candidates[0].Content.Parts[0].(genai.Text); ok {
|
||
processedText = string(textPart)
|
||
}
|
||
}
|
||
|
||
c.JSON(http.StatusOK, OCRResponse{
|
||
Success: true,
|
||
OriginalText: ocrText,
|
||
Result: processedText,
|
||
})
|
||
} |