# main.py
from fastapi import FastAPI, UploadFile, File, HTTPException, Form
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Optional
import uuid
import os
import io
import time
import logging
from datetime import datetime

import requests
from dotenv import load_dotenv
from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.http.models import VectorParams, Distance

# Database setup
from sqlalchemy import create_engine, Column, String, DateTime, Text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

# Image processing
from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPModel

# Basic logging configuration
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Load environment variables from the .env file
load_dotenv()

# Read configuration values, falling back to defaults
COLLECTION_NAME = os.getenv("COLLECTION_NAME", "posts")
VECTOR_SIZE = int(os.getenv("VECTOR_SIZE", "1280"))
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./imageboard.db")
QDRANT_URL = os.getenv("QDRANT_URL", "http://qdrant:6333")
OLLAMA_URL = os.getenv("OLLAMA_URL", "http://ollama:11435")
EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL", "nomic-embed-text")
IMAGE_MODEL = os.getenv("IMAGE_MODEL", "openai/clip-vit-base-patch32")

# IMAGE_SIZE is expected in the form "224,224"; convert it to a tuple of ints
image_size_str = os.getenv("IMAGE_SIZE", "224,224")
IMAGE_SIZE = tuple(map(int, image_size_str.split(',')))

UPLOAD_DIR = os.getenv("UPLOAD_DIR", "uploads")
os.makedirs(UPLOAD_DIR, exist_ok=True)

# Component initialization
Base = declarative_base()
engine = create_engine(DATABASE_URL)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# CLIP initialization
clip_model = CLIPModel.from_pretrained(IMAGE_MODEL)
clip_processor = CLIPProcessor.from_pretrained(IMAGE_MODEL)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
clip_model = clip_model.to(device)


# Create a QdrantClient, retrying a few times while the service starts up
def create_qdrant_client():
    max_attempts = 5
    attempt = 0
    while attempt < max_attempts:
        try:
            client = QdrantClient(QDRANT_URL)
            # Connection check
            client.get_collections()
            logger.info("Successfully connected to Qdrant")
            return client
        except Exception as e:
            logger.warning(f"Qdrant connection attempt {attempt + 1} failed: {str(e)}")
            attempt += 1
            time.sleep(2)
    raise RuntimeError(f"Could not connect to Qdrant after {max_attempts} attempts")


# Qdrant client initialization
try:
    qdrant_client = create_qdrant_client()
except Exception as e:
    logger.error(f"Qdrant initialization error: {str(e)}")
    raise


# Check that the collection exists, creating it if necessary
def ensure_collection_exists():
    max_attempts = 5
    attempt = 0
    while attempt < max_attempts:
        try:
            # Check whether the collection already exists
            qdrant_client.get_collection(collection_name=COLLECTION_NAME)
            logger.info(f"Collection '{COLLECTION_NAME}' exists")
            return
        except Exception as e:
            if "not found" in str(e).lower():
                logger.info(f"Creating collection '{COLLECTION_NAME}'...")
                try:
                    qdrant_client.create_collection(
                        collection_name=COLLECTION_NAME,
                        vectors_config=VectorParams(
                            size=VECTOR_SIZE,
                            distance=Distance.COSINE
                        )
                    )
                    logger.info(f"Collection '{COLLECTION_NAME}' created")
                    return
                except Exception as create_error:
                    logger.error(f"Creation error: {str(create_error)}")
            else:
                logger.error(f"Connection error: {str(e)}")
            attempt += 1
            time.sleep(2)
    raise RuntimeError(f"Could not initialize the collection after {max_attempts} attempts")


# Ensure the collection exists on startup
try:
    ensure_collection_exists()
except Exception as e:
    logger.error(f"Collection initialization error: {str(e)}")
    raise

# FastAPI initialization
app = FastAPI()
app.mount("/uploads", StaticFiles(directory=UPLOAD_DIR), name="uploads")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Database model
class Post(Base):
    __tablename__ = "posts"
    id = Column(String, primary_key=True, index=True)
    text = Column(Text, nullable=True)
    image = Column(String, nullable=True)
    created_at = Column(DateTime)


Base.metadata.create_all(bind=engine)


# Pydantic model for responses
class PostResponse(BaseModel):
    id: str
    text: Optional[str] = None
    image: Optional[str] = None
    created_at: datetime
    vector: Optional[List[float]] = None

    class Config:
        orm_mode = True


# Pydantic model for vector queries
class VectorQuery(BaseModel):
    vector: List[float]


# Utility functions
def generate_text_embedding(text: str) -> List[float]:
    response = requests.post(
        f"{OLLAMA_URL}/api/embeddings",
        json={"model": EMBEDDING_MODEL, "prompt": text}
    )
    if response.status_code != 200:
        raise HTTPException(status_code=500, detail="Embedding generation failed")
    return response.json()["embedding"]


def generate_image_embedding(image_bytes: bytes) -> List[float]:
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    inputs = clip_processor(
        images=image,
        return_tensors="pt",
        padding=True
    ).to(device)
    with torch.no_grad():
        features = clip_model.get_image_features(**inputs)
    return features.cpu().numpy().tolist()[0]


def process_image(image_bytes: bytes) -> bytes:
    img = Image.open(io.BytesIO(image_bytes))
    img = img.convert("RGB")
    img = img.resize(IMAGE_SIZE)
    buffer = io.BytesIO()
    img.save(buffer, format="JPEG")
    return buffer.getvalue()


# API endpoints
@app.post("/posts/", response_model=PostResponse)
async def create_post(
    text: Optional[str] = Form(None),
    image: Optional[UploadFile] = File(None)
):
    db = SessionLocal()
    try:
        post_id = str(uuid.uuid4())
        image_path = None
        thumbnail_path = None  # Placeholder for a thumbnail
        embeddings = []

        if text:
            logger.info("Generating text embedding")
            text_embedding = generate_text_embedding(text)
            embeddings.extend(text_embedding)

        if image:
            logger.info("Processing image")
            image_bytes = await image.read()

            # Save the original image
            image_path = f"{UPLOAD_DIR}/{post_id}.jpg"
            with open(image_path, "wb") as f:
                f.write(image_bytes)

            # Resize/convert the image before computing embeddings
            processed_image = process_image(image_bytes)

            # Thumbnail generation placeholder (example implementation)
            # thumbnail = generate_thumbnail(image_bytes)  # implement your thumbnail generation
            # thumbnail_path = f"{UPLOAD_DIR}/{post_id}_thumbnail.jpg"
            # with open(thumbnail_path, "wb") as f:
            #     f.write(thumbnail)

            # Generate embeddings from the processed image
            image_embedding = generate_image_embedding(processed_image)
            embeddings.extend(image_embedding)

        # Note: the text and image embeddings are concatenated, so the stored vector
        # only matches VECTOR_SIZE (1280 by default) when both text and an image are supplied.
        logger.info("Saving vector to Qdrant")
        qdrant_client.upsert(
            collection_name=COLLECTION_NAME,
            points=[models.PointStruct(
                id=post_id,
                vector=embeddings,
                payload={"post_id": post_id}
            )]
        )

        logger.info("Saving post to the database")
        db_post = Post(
            id=post_id,
            text=text,
            image=image_path,
            # thumbnail=thumbnail_path,  # add a thumbnail field to the Post model first
            created_at=datetime.now()
        )
        db.add(db_post)
        db.commit()
        db.refresh(db_post)

        response = PostResponse(
            id=db_post.id,
            text=db_post.text,
            image=db_post.image,
            # thumbnail=db_post.thumbnail,  # update the response model to include the thumbnail
            created_at=db_post.created_at,
            vector=embeddings
        )
        logger.info("Post created successfully: %s", response)
        return response
    except Exception as e:
        db.rollback()
        logger.exception("Error while creating post")
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        db.close()


@app.get("/search/")
async def search_posts(
    text: Optional[str] = None,
    image: Optional[UploadFile] = File(None)
):
    try:
        query_embedding = []

        if text:
            logger.info("Generating text embedding (search)")
            text_embedding = generate_text_embedding(text)
            query_embedding.extend(text_embedding)

        if image:
            logger.info("Generating image embedding (search)")
            image_bytes = await image.read()
            processed_image = process_image(image_bytes)
            image_embedding = generate_image_embedding(processed_image)
            query_embedding.extend(image_embedding)

        # Reject empty queries early instead of sending an empty vector to Qdrant
        if not query_embedding:
            raise HTTPException(status_code=400, detail="Provide text and/or an image to search")

        logger.info("Running search in Qdrant")
        search_results = qdrant_client.search(
            collection_name=COLLECTION_NAME,
            query_vector=query_embedding,
            limit=10
        )
        logger.info("Search finished. Results found: %d", len(search_results))
        return [result.payload for result in search_results]
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Error while searching posts")
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/posts/", response_model=List[PostResponse])
async def get_all_posts():
    db = SessionLocal()
    try:
        posts = db.query(Post).all()
        return posts
    finally:
        db.close()


# New endpoint: return a "tree" of posts ordered by similarity to a user-supplied vector
@app.post("/posts/tree", response_model=List[PostResponse])
async def get_posts_tree(query: VectorQuery):
    # Search Qdrant with a large limit to fetch all posts sorted by similarity
    search_results = qdrant_client.search(
        collection_name=COLLECTION_NAME,
        query_vector=query.vector,
        limit=10000  # set the limit according to the expected number of posts
    )
    logger.debug("search_results: %s", search_results)

    # Collect post IDs in the order Qdrant returned them (nearest first)
    post_ids = [result.payload.get("post_id") for result in search_results]
    logger.debug("post_ids: %s", post_ids)

    db = SessionLocal()
    try:
        # Fetch all matching posts from the database
        posts = db.query(Post).filter(Post.id.in_(post_ids)).all()
        # Map post_id -> post so the similarity order can be restored
        posts_dict = {post.id: post for post in posts}
        # Restore the order returned by Qdrant using the post_ids list
        ordered_posts = [posts_dict[pid] for pid in post_ids if pid in posts_dict]
        logger.debug("ordered_posts: %s", ordered_posts)
        return ordered_posts
    finally:
        db.close()


if __name__ == "__main__":
    import uvicorn

    print("COLLECTION_NAME:", COLLECTION_NAME)
    print("VECTOR_SIZE:", VECTOR_SIZE)
    print("DATABASE_URL:", DATABASE_URL)
    print("QDRANT_URL:", QDRANT_URL)
    print("OLLAMA_URL:", OLLAMA_URL)
    print("EMBEDDING_MODEL:", EMBEDDING_MODEL)
    print("IMAGE_MODEL:", IMAGE_MODEL)
    print("IMAGE_SIZE:", IMAGE_SIZE)
    print("UPLOAD_DIR:", UPLOAD_DIR)

    uvicorn.run(app, host="0.0.0.0", port=8000)
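
# Usage sketch (not part of the original file; the host, port, and file names below are
# assumptions): with the server running, the endpoints above can be exercised from a
# separate script using the `requests` library, for example:
#
#   import requests
#
#   # Create a post with both text and an image (multipart form data)
#   created = requests.post(
#       "http://localhost:8000/posts/",
#       data={"text": "sunset over the sea"},
#       files={"image": open("sunset.jpg", "rb")},
#   ).json()
#
#   # Text-only similarity search
#   hits = requests.get("http://localhost:8000/search/", params={"text": "sea"}).json()
#
#   # Ordered "tree" of posts, reusing the vector returned by create_post
#   tree = requests.post(
#       "http://localhost:8000/posts/tree",
#       json={"vector": created["vector"]},
#   ).json()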