2025-02-10 15:53:52 +03:00
|
|
|
|
# main.py
|
|
|
|
|
from fastapi import FastAPI, UploadFile, File, HTTPException, Form
|
|
|
|
|
from fastapi.staticfiles import StaticFiles
|
|
|
|
|
from pydantic import BaseModel
|
|
|
|
|
from typing import List, Optional
|
|
|
|
|
import uuid
|
|
|
|
|
import os
|
|
|
|
|
from datetime import datetime
|
|
|
|
|
import requests
|
|
|
|
|
from qdrant_client import QdrantClient
|
|
|
|
|
from qdrant_client.http import models
|
|
|
|
|
from fastapi.middleware.cors import CORSMiddleware
|
2025-02-15 13:52:05 +03:00
|
|
|
|
import time # Добавлен импорт модуля time
|
2025-02-10 15:53:52 +03:00
|
|
|
|
|
|
|
|
|
# Database setup
|
|
|
|
|
from sqlalchemy import create_engine, Column, String, DateTime, Text
|
|
|
|
|
from sqlalchemy.ext.declarative import declarative_base
|
|
|
|
|
from sqlalchemy.orm import sessionmaker
|
|
|
|
|
|
|
|
|
|
# Image processing
|
|
|
|
|
from PIL import Image
|
|
|
|
|
import io
|
|
|
|
|
import torch
|
|
|
|
|
from transformers import CLIPProcessor, CLIPModel
|
|
|
|
|
|
|
|
|
|
import logging
|
|
|
|
|
|
|
|
|
|
from qdrant_client.http.models import VectorParams, Distance
|
|
|
|
|
|
2025-02-15 13:52:05 +03:00
|
|
|
|
# Настройка базового логирования
|
2025-02-10 15:53:52 +03:00
|
|
|
|
# Basic logging setup: INFO level, each record prefixed with the timestamp,
# the level name and the name of the emitting logger.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s",
    level=logging.INFO,
)

# Module-level logger used throughout this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
2025-02-15 13:52:05 +03:00
|
|
|
|
# Конфигурация
|
2025-02-10 15:53:52 +03:00
|
|
|
|
# --- Configuration ---------------------------------------------------------

# Vector store (Qdrant).
QDRANT_URL = "http://localhost:6333"
COLLECTION_NAME = "posts"
# Dimensionality the Qdrant collection is created with.
# NOTE(review): presumably the text embedding and the CLIP image embedding
# concatenated -- confirm against the models configured below.
VECTOR_SIZE = 1280

# Relational database (SQLite file in the working directory).
DATABASE_URL = "sqlite:///./imageboard.db"

# Embedding back-ends.
OLLAMA_URL = "http://localhost:11434"
EMBEDDING_MODEL = "nomic-embed-text"
IMAGE_MODEL = "openai/clip-vit-base-patch32"

# Size images are resized to before the image embedding is computed.
IMAGE_SIZE = (224, 224)

# Directory where uploaded images are stored; created on import if missing.
UPLOAD_DIR = "uploads"
os.makedirs(UPLOAD_DIR, exist_ok=True)
|
|
|
|
|
|
2025-02-15 13:52:05 +03:00
|
|
|
|
# Инициализация компонентов
|
2025-02-10 15:53:52 +03:00
|
|
|
|
# --- SQLAlchemy ------------------------------------------------------------
# Declarative base for the ORM models, the engine bound to DATABASE_URL and a
# session factory with autocommit/autoflush switched off.
Base = declarative_base()
engine = create_engine(DATABASE_URL)
SessionLocal = sessionmaker(
    bind=engine,
    autoflush=False,
    autocommit=False,
)

# --- CLIP ------------------------------------------------------------------
# Load the CLIP model and its processor once at import time, then move the
# model to the GPU when CUDA is available and to the CPU otherwise.
clip_model = CLIPModel.from_pretrained(IMAGE_MODEL)
clip_processor = CLIPProcessor.from_pretrained(IMAGE_MODEL)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
clip_model = clip_model.to(device)
|
|
|
|
|
|
2025-02-15 13:52:05 +03:00
|
|
|
|
# Функция для создания QdrantClient с повторными попытками
|
|
|
|
|
def create_qdrant_client(max_attempts: int = 5, retry_delay: float = 2.0) -> "QdrantClient":
    """Create a Qdrant client, retrying until the server responds.

    Args:
        max_attempts: How many connection attempts to make before giving up.
        retry_delay: Seconds to wait between two consecutive attempts.

    Returns:
        A client for ``QDRANT_URL`` that has successfully listed collections.

    Raises:
        RuntimeError: If every attempt failed. The last underlying error is
            attached as ``__cause__``.
    """
    last_error = None
    for attempt in range(1, max_attempts + 1):
        try:
            client = QdrantClient(QDRANT_URL)
            # Issue a cheap request so a dead server is detected here rather
            # than on first real use.
            client.get_collections()
        except Exception as e:
            last_error = e
            logger.warning(f"Попытка {attempt} подключения к Qdrant не удалась: {str(e)}")
            # Only wait when another attempt will follow; sleeping after the
            # final failure just delays the error.
            if attempt < max_attempts:
                time.sleep(retry_delay)
        else:
            logger.info("Успешное подключение к Qdrant")
            return client
    raise RuntimeError(
        f"Не удалось подключиться к Qdrant после {max_attempts} попыток"
    ) from last_error
|
|
|
|
|
|
|
|
|
|
# Инициализация клиента Qdrant
|
|
|
|
|
# Create the shared Qdrant client at import time. A failure is logged and then
# re-raised so the application does not start without its vector store.
try:
    qdrant_client = create_qdrant_client()
except Exception as e:
    logger.error(f"Ошибка инициализации Qdrant: {str(e)}")
    raise
|
|
|
|
|
|
|
|
|
|
# Функция проверки и создания коллекции
|
2025-02-10 15:53:52 +03:00
|
|
|
|
def ensure_collection_exists(max_attempts: int = 5, retry_delay: float = 2.0) -> None:
    """Make sure the Qdrant collection ``COLLECTION_NAME`` exists.

    The collection is looked up first. When the lookup fails with an error
    whose message contains "not found", the collection is created with
    ``VECTOR_SIZE`` dimensions and cosine distance. Any other failure (or a
    failed creation) is logged and retried.

    Args:
        max_attempts: How many lookup/create rounds to try before giving up.
        retry_delay: Seconds to wait between two consecutive rounds.

    Raises:
        RuntimeError: If the collection could neither be found nor created
            within ``max_attempts`` rounds. The last underlying error is
            attached as ``__cause__``.
    """
    last_error = None
    for attempt in range(1, max_attempts + 1):
        try:
            qdrant_client.get_collection(collection_name=COLLECTION_NAME)
            logger.info(f"Коллекция '{COLLECTION_NAME}' существует")
            return
        except Exception as e:
            last_error = e
            # The lookup error is classified by its message text only.
            if "not found" in str(e).lower():
                logger.info(f"Создание коллекции '{COLLECTION_NAME}'...")
                try:
                    qdrant_client.create_collection(
                        collection_name=COLLECTION_NAME,
                        vectors_config=VectorParams(
                            size=VECTOR_SIZE,
                            distance=Distance.COSINE,
                        ),
                    )
                    logger.info(f"Коллекция '{COLLECTION_NAME}' создана")
                    return
                except Exception as create_error:
                    last_error = create_error
                    logger.error(f"Ошибка создания: {str(create_error)}")
            else:
                logger.error(f"Ошибка подключения: {str(e)}")

        # Only wait when another round will follow; sleeping after the final
        # failure just delays the error.
        if attempt < max_attempts:
            time.sleep(retry_delay)

    raise RuntimeError(
        f"Не удалось инициализировать коллекцию после {max_attempts} попыток"
    ) from last_error
|
|
|
|
|
|
|
|
|
|
# Вызов функции проверки коллекции
|
|
|
|
|
# Verify (or create) the collection at import time. A failure is logged and
# then re-raised so the application does not start without it.
try:
    ensure_collection_exists()
except Exception as e:
    logger.error(f"Ошибка инициализации коллекции: {str(e)}")
    raise
|
|
|
|
|
|
|
|
|
|
# Инициализация FastAPI
|
2025-02-10 15:53:52 +03:00
|
|
|
|
# FastAPI application instance.
app = FastAPI()

# Expose the upload directory as static files under /uploads.
app.mount(
    "/uploads",
    StaticFiles(directory=UPLOAD_DIR),
    name="uploads",
)

# CORS: every origin, method and header is accepted, with credentials enabled.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive -- confirm this is intended before exposing the service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
2025-02-15 13:52:05 +03:00
|
|
|
|
# Модель базы данных
|
2025-02-10 15:53:52 +03:00
|
|
|
|
class Post(Base):
    """ORM model for one row of the ``posts`` table."""

    __tablename__ = "posts"

    # Primary key, stored as a string.
    id = Column(String, primary_key=True, index=True)
    # Optional text body of the post.
    text = Column(Text, nullable=True)
    # Optional string referencing the post's image.
    image = Column(String, nullable=True)
    # Creation timestamp.
    created_at = Column(DateTime)


# Create all tables registered on Base (a no-op for tables that already exist).
Base.metadata.create_all(bind=engine)
|
|
|
|
|
|
|
|
|
|
# Pydantic model для ответа
|
|
|
|
|
class PostResponse(BaseModel):
    """Response schema for a post.

    ``orm_mode`` is enabled so instances can be built from ORM objects.
    """

    # Identifier of the post.
    id: str
    # Text body; absent when the post has none.
    text: Optional[str] = None
    # Image reference; absent when the post has none.
    image: Optional[str] = None
    # Creation timestamp.
    created_at: datetime
    # Embedding vector; optional and omitted by default.
    vector: Optional[List[float]] = None

    class Config:
        orm_mode = True
|
|
|
|
|
|
|
|
|
|
# Pydantic модель для запроса вектора
|
|
|
|
|
class VectorQuery(BaseModel):
    """Request body carrying a raw query vector."""

    # Query vector as a flat list of floats.
    vector: List[float]
|
|
|
|
|
|
|
|
|
|
# Utility functions
|
|
|
|
|
def generate_text_embedding(text: str, timeout: float = 60.0) -> List[float]:
    """Request an embedding for ``text`` from the Ollama embeddings endpoint.

    Args:
        text: The prompt to embed.
        timeout: Seconds to wait for Ollama before giving up. Without a
            timeout ``requests`` waits indefinitely on an unresponsive server.

    Returns:
        The value of the ``"embedding"`` field of the JSON response.

    Raises:
        HTTPException: With status 500 when the request cannot be completed,
            the response status is not 200, or the response body does not
            contain an ``"embedding"`` field.
    """
    try:
        response = requests.post(
            f"{OLLAMA_URL}/api/embeddings",
            json={"model": EMBEDDING_MODEL, "prompt": text},
            timeout=timeout,
        )
    except requests.RequestException as e:
        # Connection errors and timeouts are reported the same way as a
        # non-200 answer instead of leaking a transport-level exception.
        raise HTTPException(status_code=500, detail="Embedding generation failed") from e

    if response.status_code != 200:
        raise HTTPException(status_code=500, detail="Embedding generation failed")

    try:
        return response.json()["embedding"]
    except (ValueError, KeyError, TypeError) as e:
        # Body is not JSON, or is JSON without the expected field.
        raise HTTPException(status_code=500, detail="Embedding generation failed") from e
|
|
|
|
|
|
|
|
|
|
def generate_image_embedding(image_bytes: bytes) -> List[float]:
    """Compute the CLIP image-feature vector for an encoded image.

    Args:
        image_bytes: Encoded image data readable by PIL.

    Returns:
        The first row of the model's image features as a list of floats.
    """
    rgb_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

    # Build the model inputs and move them to the same device as the model.
    model_inputs = clip_processor(
        images=rgb_image,
        return_tensors="pt",
        padding=True,
    )
    model_inputs = model_inputs.to(device)

    # Inference only: gradient tracking is not needed.
    with torch.no_grad():
        image_features = clip_model.get_image_features(**model_inputs)

    rows = image_features.cpu().numpy().tolist()
    return rows[0]
|
|
|
|
|
|
|
|
|
|
def process_image(image_bytes: bytes) -> bytes:
    """Convert an encoded image to RGB, resize it and re-encode it as JPEG.

    Args:
        image_bytes: Encoded image data readable by PIL.

    Returns:
        JPEG bytes of the image resized to ``IMAGE_SIZE``.
    """
    source = io.BytesIO(image_bytes)
    resized = Image.open(source).convert("RGB").resize(IMAGE_SIZE)

    output = io.BytesIO()
    resized.save(output, format="JPEG")
    return output.getvalue()
|
|
|
|
|
|
2025-02-12 00:57:29 +03:00
|
|
|
|
# API endpoints
|
2025-02-10 15:53:52 +03:00
|
|
|
|
# API endpoints
|
|
|
|
|
@app.post("/posts/", response_model=PostResponse)
async def create_post(
    text: Optional[str] = Form(None),
    image: Optional[UploadFile] = File(None),
):
    """Create a post from text and/or an image.

    The text embedding and the image embedding are joined into one vector of
    ``VECTOR_SIZE`` dimensions, stored in Qdrant under the new post id, and
    the post itself is stored in the relational database.

    Args:
        text: Optional text body (form field).
        image: Optional uploaded image.

    Returns:
        The created post together with the vector stored for it.

    Raises:
        HTTPException: 400 when neither text nor an image is supplied or the
            upload cannot be read as an image; 500 on any other failure.
    """
    # An empty post has nothing to embed; reject it before touching any store.
    if not text and not image:
        raise HTTPException(
            status_code=400,
            detail="A post must contain text or an image",
        )

    db = SessionLocal()
    image_path = None
    committed = False
    try:
        post_id = str(uuid.uuid4())
        text_embedding: List[float] = []
        image_embedding: List[float] = []

        if text:
            logger.info("Генерация эмбеддинга для текста")
            text_embedding = generate_text_embedding(text)

        if image:
            logger.info("Обработка изображения")
            image_bytes = await image.read()

            # Validate and embed the upload BEFORE writing it to disk, so an
            # unreadable file is rejected without leaving anything behind.
            try:
                processed_image = process_image(image_bytes)
            except (OSError, ValueError) as e:
                raise HTTPException(
                    status_code=400,
                    detail="Uploaded file is not a valid image",
                ) from e
            image_embedding = generate_image_embedding(processed_image)

            # Store the original upload.
            # NOTE(review): the file is always named *.jpg regardless of the
            # uploaded format -- confirm this is acceptable.
            # TODO: thumbnail generation is not implemented yet.
            image_path = f"{UPLOAD_DIR}/{post_id}.jpg"
            with open(image_path, "wb") as f:
                f.write(image_bytes)

        # The collection is created with VECTOR_SIZE dimensions, so a post
        # with only one modality must be zero-padded to that size.
        # Assumes the text part occupies the leading slots and the image part
        # the trailing slots, matching the order used when both are present.
        missing = VECTOR_SIZE - len(text_embedding) - len(image_embedding)
        embeddings = text_embedding + [0.0] * max(missing, 0) + image_embedding

        logger.info("Сохранение данных в Qdrant")
        qdrant_client.upsert(
            collection_name=COLLECTION_NAME,
            points=[
                models.PointStruct(
                    id=post_id,
                    vector=embeddings,
                    payload={"post_id": post_id},
                )
            ],
        )

        logger.info("Сохранение поста в базу данных")
        db_post = Post(
            id=post_id,
            text=text,
            image=image_path,
            created_at=datetime.now(),
        )
        db.add(db_post)
        db.commit()
        committed = True
        db.refresh(db_post)

        response = PostResponse(
            id=db_post.id,
            text=db_post.text,
            image=db_post.image,
            created_at=db_post.created_at,
            vector=embeddings,
        )
        # Log the id only; the full response carries the whole vector.
        logger.info("Пост успешно создан: %s", db_post.id)
        return response
    except HTTPException:
        # Already carries the intended status code; do not re-wrap as 500.
        db.rollback()
        raise
    except Exception as e:
        db.rollback()
        logger.exception("Ошибка при создании поста")
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Remove the saved upload when the post was not persisted.
        if not committed and image_path is not None:
            try:
                os.remove(image_path)
            except FileNotFoundError:
                pass
            except OSError:
                logger.warning("Could not remove orphaned upload %s", image_path)
        db.close()
|
2025-02-12 00:57:29 +03:00
|
|
|
|
|
2025-02-10 15:53:52 +03:00
|
|
|
|
@app.get("/search/")
async def search_posts(
    text: Optional[str] = None,
    image: Optional[UploadFile] = File(None),
):
    """Search posts by text and/or image similarity.

    The query embedding is built the same way a post vector is (text part,
    then image part) and the ten nearest points are requested from Qdrant.

    NOTE(review): this is a GET route that also accepts a file upload; many
    HTTP clients cannot send a body with GET -- confirm the image parameter
    is reachable in practice.

    Args:
        text: Optional query text.
        image: Optional query image.

    Returns:
        The payloads of the matching Qdrant points, nearest first.

    Raises:
        HTTPException: 400 when neither text nor an image is supplied or the
            upload cannot be read as an image; 500 on any other failure.
    """
    # Without any input there is no vector to search with.
    if not text and not image:
        raise HTTPException(
            status_code=400,
            detail="Search requires text or an image",
        )

    try:
        text_embedding: List[float] = []
        image_embedding: List[float] = []

        if text:
            logger.info("Генерация эмбеддинга для текста (поиск)")
            text_embedding = generate_text_embedding(text)

        if image:
            logger.info("Генерация эмбеддинга для изображения (поиск)")
            image_bytes = await image.read()
            try:
                processed_image = process_image(image_bytes)
            except (OSError, ValueError) as e:
                raise HTTPException(
                    status_code=400,
                    detail="Uploaded file is not a valid image",
                ) from e
            image_embedding = generate_image_embedding(processed_image)

        # The collection is created with VECTOR_SIZE dimensions, so a query
        # with only one modality must be zero-padded to that size.
        # Assumes the text part occupies the leading slots and the image part
        # the trailing slots, matching the order used when both are present.
        missing = VECTOR_SIZE - len(text_embedding) - len(image_embedding)
        query_embedding = text_embedding + [0.0] * max(missing, 0) + image_embedding

        logger.info("Выполнение поиска в Qdrant")
        search_results = qdrant_client.search(
            collection_name=COLLECTION_NAME,
            query_vector=query_embedding,
            limit=10,
        )

        logger.info("Поиск завершён. Найдено результатов: %d", len(search_results))
        return [result.payload for result in search_results]
    except HTTPException:
        # Already carries the intended status code; do not re-wrap as 500.
        raise
    except Exception as e:
        logger.exception("Ошибка при поиске постов")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
|
|
|
|
|
@app.get("/posts/", response_model=List[PostResponse])
async def get_all_posts():
    """Return every post stored in the database."""
    session = SessionLocal()
    try:
        return session.query(Post).all()
    finally:
        # Always release the session, whether or not the query succeeded.
        session.close()
|
|
|
|
|
|
|
|
|
|
# Новый endpoint: получение "древа" постов по вектору пользователя
|
|
|
|
|
@app.post("/posts/tree", response_model=List[PostResponse])
async def get_posts_tree(query: VectorQuery):
    """Return posts ordered by similarity to the supplied vector.

    Qdrant is searched with a large limit, the ``post_id`` of every hit is
    collected in result order, the matching rows are loaded from the
    database, and the rows are returned in the order Qdrant ranked them.
    Hits without a matching database row are skipped.

    Args:
        query: Request body carrying the query vector.

    Returns:
        The posts found in the database, nearest first.
    """
    # Large limit so that (practically) all posts come back, sorted by
    # similarity; adjust to the expected number of posts.
    search_results = qdrant_client.search(
        collection_name=COLLECTION_NAME,
        query_vector=query.vector,
        limit=10000,
    )
    logger.debug("Поиск дерева постов: результатов в Qdrant: %d", len(search_results))

    # Post ids in the order Qdrant returned them (nearest to farthest).
    # Points without a payload or without a post_id cannot be mapped to a row.
    post_ids = [
        result.payload["post_id"]
        for result in search_results
        if result.payload and result.payload.get("post_id")
    ]
    if not post_ids:
        return []

    # SQLite limits the number of bound parameters per statement (999 by
    # default in older builds), so the ids are looked up in batches instead
    # of one IN (...) with up to 10000 values.
    batch_size = 500

    db = SessionLocal()
    try:
        posts_by_id = {}
        for start in range(0, len(post_ids), batch_size):
            batch = post_ids[start:start + batch_size]
            for post in db.query(Post).filter(Post.id.in_(batch)).all():
                posts_by_id[post.id] = post

        # Restore Qdrant's ranking; ids with no database row are dropped.
        ordered_posts = [posts_by_id[pid] for pid in post_ids if pid in posts_by_id]
        logger.debug("Поиск дерева постов: постов в БД: %d", len(ordered_posts))
        return ordered_posts
    finally:
        db.close()
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed as a script: serve the app with uvicorn on all interfaces,
    # port 8000. uvicorn is imported here so it is only needed in this mode.
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")
|