tencent_ocr/.history/pkg/handler/upload_20250115142606.go
2025-01-15 16:01:18 +08:00

162 lines
4.3 KiB
Go

// Package handler contains HTTP handlers, including image upload to Cloudflare R2.
package handler
import (
	"bytes"
	"crypto/rand"
	"encoding/base64"
	"encoding/hex"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strings"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/credentials"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/s3"
	"github.com/gin-gonic/gin"

	"your-project/pkg/service"
)
// UploadHandler serves the image-upload endpoint. It OCRs uploaded images,
// stores them in a Cloudflare R2 bucket, and can rewrite the combined OCR
// text with Gemini.
type UploadHandler struct {
	// S3-compatible credentials, target bucket and API endpoint used for R2.
	accessKey, secretKey, bucket, endpoint string
	// customDomain is the host placed in the public URLs returned for
	// uploaded objects.
	customDomain string

	ocrService    *OCRService            // extracts text from base64-encoded images
	geminiService *service.GeminiService // post-processes text from multi-image uploads
}
// MultiUploadResponse is the JSON body HandleMultiUpload sends on success.
// Error responses do not use this type; they are {"error": "..."} objects.
type MultiUploadResponse struct {
	// ImageURLs holds the public URL of each uploaded image, in upload order.
	ImageURLs []string `json:"image_urls"`
	// Text is the OCR result. When several images were uploaded, it is the
	// Gemini-rewritten combination of their OCR texts.
	Text string `json:"text"`
	// Success is always true in responses built by HandleMultiUpload.
	Success bool `json:"success"`
}
// HandleMultiUpload handles a multipart upload of 1 to 5 images sent in the
// form field "files". Each image is OCR'd and uploaded to R2. The response
// carries the public image URLs and the recognised text.
//
// With more than one image, the per-image OCR texts are joined with newlines
// and sent to Gemini to be rewritten as one coherent passage. With a single
// image, its OCR text is returned unchanged.
//
// Status codes:
//   - 400: malformed form, no files, too many files, a file over 10 MB, or a
//     non-image file.
//   - 500: opening or reading a file, OCR, the R2 upload, or the Gemini call
//     failed.
//
// Underlying causes are attached with c.Error so middleware can log them
// without exposing them to the client.
func (h *UploadHandler) HandleMultiUpload(c *gin.Context) {
	const (
		maxFiles    = 5
		maxFileSize = 10 << 20 // 10 MB
	)

	form, err := c.MultipartForm()
	if err != nil {
		_ = c.Error(err)
		c.JSON(http.StatusBadRequest, gin.H{"error": "Failed to parse form"})
		return
	}
	files := form.File["files"]
	if len(files) == 0 {
		c.JSON(http.StatusBadRequest, gin.H{"error": "No files uploaded"})
		return
	}
	if len(files) > maxFiles {
		c.JSON(http.StatusBadRequest, gin.H{"error": "Maximum 5 files allowed"})
		return
	}

	// Pass 1: read and validate every file before any OCR call or upload.
	// This way a bad file later in the batch cannot leave earlier files
	// orphaned in the bucket. Holds at most maxFiles*maxFileSize bytes.
	type validImage struct {
		data        []byte
		contentType string
	}
	images := make([]validImage, 0, len(files))
	for _, fileHeader := range files {
		if fileHeader.Size > maxFileSize {
			c.JSON(http.StatusBadRequest, gin.H{"error": "File size exceeds the limit of 10MB"})
			return
		}
		file, err := fileHeader.Open()
		if err != nil {
			_ = c.Error(err)
			c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to open file"})
			return
		}
		data, err := io.ReadAll(file)
		// Close immediately rather than defer. A deferred Close inside the loop
		// would keep every file open until the handler returns. A Close error
		// on a read-only handle is not actionable, so it is ignored.
		_ = file.Close()
		if err != nil {
			_ = c.Error(err)
			c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to read file"})
			return
		}
		contentType := http.DetectContentType(data)
		if !isImage(contentType) {
			c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid file type. Only images are allowed"})
			return
		}
		images = append(images, validImage{data: data, contentType: contentType})
	}

	// Pass 2: OCR each image and upload it to R2.
	ctx := c.Request.Context()
	imageURLs := make([]string, 0, len(images))
	ocrTexts := make([]string, 0, len(images))
	for _, img := range images {
		ocrText, err := h.ocrService.ProcessImage(ctx, base64.StdEncoding.EncodeToString(img.data))
		if err != nil {
			_ = c.Error(err)
			c.JSON(http.StatusInternalServerError, gin.H{"error": "OCR processing failed"})
			return
		}
		ocrTexts = append(ocrTexts, ocrText)

		// The client-supplied filename is deliberately not used as the key.
		// It is untrusted, may contain characters unsafe in a key or URL, and
		// identical names from different uploads would overwrite each other's
		// public objects.
		key, err := newObjectKey(img.contentType)
		if err != nil {
			_ = c.Error(err)
			c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to upload file"})
			return
		}
		imageURL, err := h.uploadToR2(img.data, key, img.contentType)
		if err != nil {
			_ = c.Error(err)
			c.JSON(http.StatusInternalServerError, gin.H{"error": "Failed to upload file"})
			return
		}
		imageURLs = append(imageURLs, imageURL)
	}

	finalText := strings.Join(ocrTexts, "\n")
	if len(ocrTexts) > 1 {
		// Prompt (Chinese): "Reorganize the following passages into one fluent
		// text, keeping the original meaning while ensuring correct grammar and
		// logic:"
		prompt := "请将以下多段文字重新组织成一段通顺的文字,保持原意的同时确保语法和逻辑正确:\n\n" + finalText
		processedText, err := h.geminiService.ProcessText(ctx, prompt)
		if err != nil {
			_ = c.Error(err)
			c.JSON(http.StatusInternalServerError, gin.H{"error": "Text processing failed"})
			return
		}
		finalText = processedText
	}

	c.JSON(http.StatusOK, MultiUploadResponse{
		ImageURLs: imageURLs,
		Text:      finalText,
		Success:   true,
	})
}

// newObjectKey returns a fresh R2 object key for an image of the given
// content type. The key is 32 random hex characters plus a matching
// extension; the extension is empty for types not listed below.
func newObjectKey(contentType string) (string, error) {
	var id [16]byte
	if _, err := rand.Read(id[:]); err != nil {
		return "", fmt.Errorf("generating object key: %w", err)
	}
	var ext string
	switch contentType {
	case "image/jpeg":
		ext = ".jpg"
	case "image/png":
		ext = ".png"
	case "image/gif":
		ext = ".gif"
	case "image/bmp":
		ext = ".bmp"
	case "image/tiff":
		ext = ".tiff"
	case "image/webp":
		ext = ".webp"
	}
	return hex.EncodeToString(id[:]) + ext, nil
}
// uploadToR2 uploads file to Cloudflare R2. It writes to the configured
// bucket under the object key fileName, with the given Content-Type. It
// returns the object's public URL on h.customDomain.
//
// When building the URL, the key is percent-escaped one path segment at a
// time, keeping the '/' separators. This way keys containing spaces, '?',
// '#' or non-ASCII characters still produce a valid link.
func (h *UploadHandler) uploadToR2(file []byte, fileName, contentType string) (string, error) {
	// NOTE(review): a new session and S3 client are created on every call.
	// Caching one client on UploadHandler would avoid the repeated setup.
	// PutObjectWithContext would also allow request cancellation, but that
	// needs a ctx parameter.
	sess, err := session.NewSession(&aws.Config{
		Endpoint:    aws.String(h.endpoint),
		Region:      aws.String("auto"), // region value expected by R2's S3 API (per Cloudflare docs)
		Credentials: credentials.NewStaticCredentials(h.accessKey, h.secretKey, ""),
	})
	if err != nil {
		return "", fmt.Errorf("failed to create S3 session: %w", err)
	}

	svc := s3.New(sess)
	_, err = svc.PutObject(&s3.PutObjectInput{
		Bucket:      aws.String(h.bucket),
		Key:         aws.String(fileName),
		Body:        bytes.NewReader(file),
		ContentType: aws.String(contentType),
		// NOTE(review): R2 documents object ACLs as unsupported. Public
		// access likely comes from the bucket's public or custom-domain
		// setting; confirm this header is accepted.
		ACL: aws.String("public-read"),
	})
	if err != nil {
		return "", fmt.Errorf("failed to upload file to R2: %w", err)
	}

	segments := strings.Split(fileName, "/")
	for i, seg := range segments {
		segments[i] = url.PathEscape(seg)
	}
	return fmt.Sprintf("https://%s/%s", h.customDomain, strings.Join(segments, "/")), nil
}
// isImage reports whether contentType exactly matches one of the accepted
// image MIME types: JPEG, PNG, GIF, BMP, TIFF or WebP. Values carrying
// parameters (e.g. "; charset=...") do not match.
func isImage(contentType string) bool {
	switch contentType {
	case "image/jpeg", "image/png", "image/gif", "image/bmp", "image/tiff", "image/webp":
		return true
	default:
		return false
	}
}