# main.py from fastapi import FastAPI, UploadFile, File, HTTPException, Form from fastapi.staticfiles import StaticFiles from pydantic import BaseModel from typing import List, Optional import uuid import os from datetime import datetime import requests from qdrant_client import QdrantClient from qdrant_client.http import models from fastapi.middleware.cors import CORSMiddleware # Database setup from sqlalchemy import create_engine, Column, String, DateTime, Text from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker # Image processing from PIL import Image import io import torch from transformers import CLIPProcessor, CLIPModel import logging from qdrant_client.http.models import VectorParams, Distance # Настройка базового логирования (например, вывод в консоль) logging.basicConfig( level=logging.INFO, # Можно изменить уровень, например, на DEBUG format="%(asctime)s [%(levelname)s] %(name)s - %(message)s" ) logger = logging.getLogger(__name__) # Задайте имя коллекции COLLECTION_NAME = "posts" # Определите размер вектора. Этот размер должен соответствовать длине объединённого эмбеддинга текста и изображения. VECTOR_SIZE = 1280 # Пример: поменяйте на актуальное значение для вашего случая # Configuration DATABASE_URL = "sqlite:///./imageboard.db" QDRANT_URL = "http://localhost:6333" OLLAMA_URL = "http://localhost:11434" EMBEDDING_MODEL = "nomic-embed-text" # Локальная модель через Ollama IMAGE_MODEL = "openai/clip-vit-base-patch32" # Локальная CLIP модель IMAGE_SIZE = (224, 224) UPLOAD_DIR = "uploads" os.makedirs(UPLOAD_DIR, exist_ok=True) # Initialize components Base = declarative_base() engine = create_engine(DATABASE_URL) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) #Base.metadata.drop_all(bind=engine) # Удаляет все таблицы #Base.metadata.create_all(bind=engine) # Создаёт таблицы заново # Инициализация CLIP для изображений clip_model = CLIPModel.from_pretrained(IMAGE_MODEL) clip_processor = CLIPProcessor.from_pretrained(IMAGE_MODEL) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") clip_model = clip_model.to(device) # Qdrant клиент qdrant_client = QdrantClient(QDRANT_URL) def ensure_collection_exists(): try: # Попытка получить коллекцию. Если коллекция не существует, Qdrant выбросит исключение. qdrant_client.get_collection(collection_name=COLLECTION_NAME) logger.info("Коллекция '%s' существует.", COLLECTION_NAME) except Exception as e: logger.info("Коллекция '%s' не найдена. Создаём коллекцию...", COLLECTION_NAME) qdrant_client.create_collection( collection_name=COLLECTION_NAME, vectors_config=VectorParams( size=VECTOR_SIZE, distance=Distance.COSINE # Или другой подходящий тип расстояния ) ) logger.info("Коллекция '%s' создана.", COLLECTION_NAME) # Вызываем функцию при инициализации приложения, например, в начале main.py ensure_collection_exists() app = FastAPI() app.mount("/uploads", StaticFiles(directory=UPLOAD_DIR), name=UPLOAD_DIR) app.add_middleware( CORSMiddleware, allow_origins=["*"], # Разрешить все источники allow_credentials=True, allow_methods=["*"], # Разрешить все методы allow_headers=["*"], # Разрешить все заголовки ) # Database models class Post(Base): __tablename__ = "posts" id = Column(String, primary_key=True, index=True) text = Column(Text, nullable=True) image = Column(String, nullable=True) created_at = Column(DateTime) Base.metadata.create_all(bind=engine) # Pydantic model для ответа class PostResponse(BaseModel): id: str text: Optional[str] = None image: Optional[str] = None created_at: datetime vector: Optional[List[float]] = None class Config: orm_mode = True # Pydantic модель для запроса вектора class VectorQuery(BaseModel): vector: List[float] # Utility functions def generate_text_embedding(text: str) -> List[float]: response = requests.post( f"{OLLAMA_URL}/api/embeddings", json={"model": EMBEDDING_MODEL, "prompt": text} ) if response.status_code != 200: raise HTTPException(status_code=500, detail="Embedding generation failed") return response.json()["embedding"] def generate_image_embedding(image_bytes: bytes) -> List[float]: image = Image.open(io.BytesIO(image_bytes)).convert("RGB") inputs = clip_processor( images=image, return_tensors="pt", padding=True ).to(device) with torch.no_grad(): features = clip_model.get_image_features(**inputs) return features.cpu().numpy().tolist()[0] def process_image(image_bytes: bytes) -> bytes: img = Image.open(io.BytesIO(image_bytes)) img = img.convert("RGB") img = img.resize(IMAGE_SIZE) buffer = io.BytesIO() img.save(buffer, format="JPEG") return buffer.getvalue() # API endpoints # API endpoints @app.post("/posts/", response_model=PostResponse) async def create_post( text: Optional[str] = Form(None), image: Optional[UploadFile] = File(None) ): db = SessionLocal() try: post_id = str(uuid.uuid4()) image_path = None thumbnail_path = None # Placeholder for thumbnail embeddings = [] if text: logger.info("Генерация эмбеддинга для текста") text_embedding = generate_text_embedding(text) embeddings.extend(text_embedding) if image: logger.info("Обработка изображения") image_bytes = await image.read() # Save original image image_path = f"{UPLOAD_DIR}/{post_id}.jpg" with open(image_path, "wb") as f: f.write(image_bytes) # Create processed image for embeddings processed_image = process_image(image_bytes) # Assume this resizes/image processing # Generate thumbnail as placeholder (example implementation) #thumbnail = generate_thumbnail(image_bytes) # Implement your thumbnail generation #thumbnail_path = f"{UPLOAD_DIR}/{post_id}_thumbnail.jpg" #with open(thumbnail_path, "wb") as f: # f.write(thumbnail) # Generate embeddings from processed image image_embedding = generate_image_embedding(processed_image) embeddings.extend(image_embedding) logger.info("Сохранение данных в Qdrant") qdrant_client.upsert( collection_name="posts", points=[models.PointStruct( id=post_id, vector=embeddings, payload={"post_id": post_id} )] ) logger.info("Сохранение поста в базу данных") db_post = Post( id=post_id, text=text, image=image_path, #thumbnail=thumbnail_path, Add thumbnail field to your Post model created_at=datetime.now() ) db.add(db_post) db.commit() db.refresh(db_post) response = PostResponse( id=db_post.id, text=db_post.text, image=db_post.image, #thumbnail=db_post.thumbnail, Update response model to include thumbnail created_at=db_post.created_at, vector=embeddings ) logger.info("Пост успешно создан: %s", response) return response except Exception as e: db.rollback() logger.exception("Ошибка при создании поста") raise HTTPException(status_code=500, detail=str(e)) finally: db.close() @app.get("/search/") async def search_posts( text: Optional[str] = None, image: Optional[UploadFile] = File(None) ): try: query_embedding = [] if text: logger.info("Генерация эмбеддинга для текста (поиск)") text_embedding = generate_text_embedding(text) query_embedding.extend(text_embedding) if image: logger.info("Генерация эмбеддинга для изображения (поиск)") image_bytes = await image.read() processed_image = process_image(image_bytes) image_embedding = generate_image_embedding(processed_image) query_embedding.extend(image_embedding) logger.info("Выполнение поиска в Qdrant") search_results = qdrant_client.search( collection_name="posts", query_vector=query_embedding, limit=10 ) logger.info("Поиск завершён. Найдено результатов: %d", len(search_results)) return [result.payload for result in search_results] except Exception as e: logger.exception("Ошибка при поиске постов") raise HTTPException(status_code=500, detail=str(e)) @app.get("/posts/", response_model=List[PostResponse]) async def get_all_posts(): db = SessionLocal() try: posts = db.query(Post).all() return posts finally: db.close() # Новый endpoint: получение "древа" постов по вектору пользователя @app.post("/posts/tree", response_model=List[PostResponse]) async def get_posts_tree(query: VectorQuery): # Выполняем поиск в Qdrant с большим лимитом, чтобы получить все посты, отсортированные по сходству search_results = qdrant_client.search( collection_name="posts", query_vector=query.vector, limit=10000 # Задайте лимит в зависимости от ожидаемого числа постов ) print("search_results") print(search_results) # Извлекаем список ID постов в том порядке, в котором Qdrant вернул результаты (от ближайших к дальним) post_ids = [result.payload.get("post_id") for result in search_results] print("post_ids") print(post_ids) db = SessionLocal() try: # Получаем все посты из БД по списку ID posts = db.query(Post).filter(Post.id.in_(post_ids)).all() print("posts") print(posts) # Создаём словарь для сохранения соответствия post_id -> post posts_dict = {post.id: post for post in posts} print("posts_dict") print(posts_dict) # Восстанавливаем порядок, используя список post_ids ordered_posts = [posts_dict[pid] for pid in post_ids if pid in posts_dict] print("ordered_posts") print(ordered_posts) return ordered_posts finally: db.close() if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=8000)